diff --git a/CHANGES.txt b/CHANGES.txt index 7680b8d841cc..a90144186ae0 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,25 @@ -4.0 +4.0-alpha2 + * Extract an AbstractCompactionController to allow for custom implementations (CASSANDRA-15286) + * Move chronicle-core version from snapshot to stable, and include carrotsearch in generated pom.xml (CASSANDRA-15321) + * Untangle RepairMessage sub-hierarchy of messages, use new messaging (more) correctly (CASSANDRA-15163) + * Add `allocate_tokens_for_local_replication_factor` option for token allocation (CASSANDRA-15260) + * Add Alibaba Cloud Platform snitch (CASSANDRA-15092) + + +4.0-alpha1 + * Inaccurate exception message with nodetool snapshot (CASSANDRA-15287) + * Fix InternodeOutboundMetrics overloaded bytes/count mixup (CASSANDRA-15186) + * Enhance & reenable RepairTest with compression=off and compression=on (CASSANDRA-15272) + * Improve readability of Table metrics Virtual tables units (CASSANDRA-15194) + * Fix error with non-existent table for nodetool tablehistograms (CASSANDRA-14410) + * Avoid result truncation in decimal operations (CASSANDRA-15232) + * Catch non-IOException in FileUtils.close to make sure that all resources are closed (CASSANDRA-15225) + * Align load column in nodetool status output (CASSANDRA-14787) + * CassandraNetworkAuthorizer uses cached roles info (CASSANDRA-15089) + * Introduce optional timeouts for idle client sessions (CASSANDRA-11097) + * Fix AlterTableStatement dropped type validation order (CASSANDRA-15203) + * Update Netty dependencies to latest, clean up SocketFactory (CASSANDRA-15195) + * Native Transport - Apply noSpamLogger to ConnectionLimitHandler (CASSANDRA-15167) * Reduce heap pressure during compactions (CASSANDRA-14654) * Support building Cassandra with JDK 11 (CASSANDRA-15108) * Use quilt to patch cassandra.in.sh in Debian packaging (CASSANDRA-14710) @@ -358,13 +379,30 @@ * nodetool clearsnapshot requires --all to clear all snapshots (CASSANDRA-13391) * Correctly 
count range tombstones in traces and tombstone thresholds (CASSANDRA-8527) * cqlshrc.sample uses incorrect option for time formatting (CASSANDRA-14243) + * Multi-version in-JVM dtests (CASSANDRA-14937) + * Allow instance class loaders to be garbage collected for inJVM dtest (CASSANDRA-15170) 3.11.5 + * Make sure user defined compaction transactions are always closed (CASSANDRA-15123) * Fix cassandra-env.sh to use $CASSANDRA_CONF to find cassandra-jaas.config (CASSANDRA-14305) * Fixed nodetool cfstats printing index name twice (CASSANDRA-14903) * Add flag to disable SASI indexes, and warnings on creation (CASSANDRA-14866) Merged from 3.0: + * Fix LegacyLayout RangeTombstoneList IndexOutOfBoundsException when upgrading and RangeTombstone bounds are asymmetric (CASSANDRA-15172) + * Fix NPE when using allocate_tokens_for_keyspace on new DC/rack (CASSANDRA-14952) + * Filter sstables earlier when running cleanup (CASSANDRA-15100) + * Use mean row count instead of mean column count for index selectivity calculation (CASSANDRA-15259) + * Avoid updating unchanged gossip states (CASSANDRA-15097) + * Prevent recreation of previously dropped columns with a different kind (CASSANDRA-14948) + * Prevent client requests from blocking on executor task queue (CASSANDRA-15013) + * Toughen up column drop/recreate type validations (CASSANDRA-15204) + * LegacyLayout should handle paging states that cross a collection column (CASSANDRA-15201) + * Prevent RuntimeException when username or password is empty/null (CASSANDRA-15198) + * Multiget thrift query returns null records after digest mismatch (CASSANDRA-14812) + * Handle paging states serialized with a different version than the session's (CASSANDRA-15176) + * Throw IOE instead of asserting on unsupporter peer versions (CASSANDRA-15066) + * Update token metadata when handling MOVING/REMOVING_TOKEN events (CASSANDRA-15120) * Add ability to customize cassandra log directory using $CASSANDRA_LOG_DIR (CASSANDRA-15090) * Fix assorted 
gossip races and add related runtime checks (CASSANDRA-15059) * cassandra-stress works with frozen collections: list and set (CASSANDRA-14907) @@ -375,6 +413,7 @@ Merged from 3.0: * Add missing commands to nodetool_completion (CASSANDRA-14916) * Anti-compaction temporarily corrupts sstable state for readers (CASSANDRA-15004) Merged from 2.2: + * Handle exceptions during authentication/authorization (CASSANDRA-15041) * Support cross version messaging in in-jvm upgrade dtests (CASSANDRA-15078) * Fix index summary redistribution cancellation (CASSANDRA-15045) * Refactor Circle CI configuration (CASSANDRA-14806) diff --git a/build.xml b/build.xml index acd3b7adb737..3eb074fa226a 100644 --- a/build.xml +++ b/build.xml @@ -25,7 +25,7 @@ - + @@ -118,7 +118,7 @@ - + @@ -548,7 +548,8 @@ - + + @@ -771,6 +772,7 @@ + - + @@ -1168,7 +1170,6 @@ @@ -1522,6 +1523,7 @@ + @@ -1945,6 +1947,20 @@ + + + + + + + + + + + + + + @@ -1954,10 +1970,24 @@ - + - + + + + + + + + + + + + + diff --git a/conf/cassandra.yaml b/conf/cassandra.yaml index ca854ca65e68..f3e5c7507c08 100644 --- a/conf/cassandra.yaml +++ b/conf/cassandra.yaml @@ -26,15 +26,21 @@ num_tokens: 256 # Triggers automatic allocation of num_tokens tokens for this node. The allocation # algorithm attempts to choose tokens in a way that optimizes replicated load over -# the nodes in the datacenter for the replication strategy used by the specified -# keyspace. +# the nodes in the datacenter for the replica factor. # # The load assigned to each node will be close to proportional to its number of # vnodes. # # Only supported with the Murmur3Partitioner. + +# Replica factor is determined via the replication strategy used by the specified +# keyspace. # allocate_tokens_for_keyspace: KEYSPACE +# Replica factor is explicitly set, regardless of keyspace or datacenter. +# This is the replica factor within the datacenter, like NTS. +# allocate_tokens_for_local_replication_factor: 3 + # initial_token allows you to specify tokens manually. 
While you can use it with # vnodes (num_tokens > 1, above) -- in which case you should provide a # comma-separated list -- it's primarily used when adding nodes to legacy clusters @@ -683,6 +689,16 @@ native_transport_port: 9042 # The default is true, which means all supported protocols will be honored. native_transport_allow_older_protocols: true +# Controls when idle client connections are closed. Idle connections are ones that had neither reads +# nor writes for a time period. +# +# Clients may implement heartbeats by sending OPTIONS native protocol message after a timeout, which +# will reset idle timeout timer on the server side. To close idle client connections, corresponding +# values for heartbeat intervals have to be set on the client side. +# +# Idle connection timeouts are disabled by default. +# native_transport_idle_timeout_in_ms: 60000 + # The address or interface to bind the native transport server to. # # Set rpc_address OR rpc_interface, not both. @@ -870,6 +886,32 @@ request_timeout_in_ms: 10000 # which picks up the OS default and configure the net.ipv4.tcp_retries2 sysctl to be ~8. # internode_tcp_user_timeout_in_ms = 30000 +# The maximum continuous period a connection may be unwritable in application space +# internode_application_timeout_in_ms = 30000 + +# Global, per-endpoint and per-connection limits imposed on messages queued for delivery to other nodes +# and waiting to be processed on arrival from other nodes in the cluster. These limits are applied to the on-wire +# size of the message being sent or received. +# +# The basic per-link limit is consumed in isolation before any endpoint or global limit is imposed. +# Each node-pair has three links: urgent, small and large. 
So any given node may have a maximum of +# N*3*(internode_application_send_queue_capacity_in_bytes+internode_application_receive_queue_capacity_in_bytes) +# messages queued without any coordination between them although in practice, with token-aware routing, only RF*tokens +# nodes should need to communicate with significant bandwidth. +# +# The per-endpoint limit is imposed on all messages exceeding the per-link limit, simultaneously with the global limit, +# on all links to or from a single node in the cluster. +# The global limit is imposed on all messages exceeding the per-link limit, simultaneously with the per-endpoint limit, +# on all links to or from any node in the cluster. +# +# internode_application_send_queue_capacity_in_bytes: 4194304 #4MiB +# internode_application_send_queue_reserve_endpoint_capacity_in_bytes: 134217728 #128MiB +# internode_application_send_queue_reserve_global_capacity_in_bytes: 536870912 #512MiB +# internode_application_receive_queue_capacity_in_bytes: 4194304 #4MiB +# internode_application_receive_queue_reserve_endpoint_capacity_in_bytes: 134217728 #128MiB +# internode_application_receive_queue_reserve_global_capacity_in_bytes: 536870912 #512MiB + + # How long before a node logs slow queries. Select queries that take longer than # this timeout to execute, will generate an aggregated log message, so that slow queries # can be identified. Set this value to zero to disable slow query logging. @@ -1305,4 +1347,4 @@ enable_sasi_indexes: false # Enables creation of transiently replicated keyspaces on this node. # Transient replication is experimental and is not recommended for production use. 
-enable_transient_replication: false \ No newline at end of file +enable_transient_replication: false diff --git a/debian/changelog b/debian/changelog index 273c6bc2ed9e..edd48a8aa0e1 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,14 @@ -cassandra (4.0) UNRELEASED; urgency=medium +cassandra (4.0~alpha2) UNRELEASED; urgency=medium * New release - -- Michael Shuler Wed, 15 Feb 2017 18:20:09 -0600 + -- + +cassandra (4.0~alpha1) unstable; urgency=medium + + * New release + + -- Michael Shuler Tue, 03 Sep 2019 11:51:18 -0500 cassandra (3.10) unstable; urgency=medium diff --git a/debian/patches/cassandra_logdir_fix.diff b/debian/patches/cassandra_logdir_fix.diff index d75553c11a37..85973313324c 100644 --- a/debian/patches/cassandra_logdir_fix.diff +++ b/debian/patches/cassandra_logdir_fix.diff @@ -1,14 +1,14 @@ --- a/bin/cassandra +++ b/bin/cassandra -@@ -171,7 +171,7 @@ - props="$3" - class="$4" - cassandra_parms="-Dlogback.configurationFile=logback.xml" -- cassandra_parms="$cassandra_parms -Dcassandra.logdir=$CASSANDRA_HOME/logs" -+ cassandra_parms="$cassandra_parms -Dcassandra.logdir=/var/log/cassandra" - cassandra_parms="$cassandra_parms -Dcassandra.storagedir=$cassandra_storagedir" +@@ -109,7 +109,7 @@ + fi + + if [ -z "$CASSANDRA_LOG_DIR" ]; then +- CASSANDRA_LOG_DIR=$CASSANDRA_HOME/logs ++ CASSANDRA_LOG_DIR=/var/log/cassandra + fi - if [ "x$pidpath" != "x" ]; then + # Special-case path variables. --- a/conf/cassandra-env.sh +++ b/conf/cassandra-env.sh @@ -93,16 +93,16 @@ diff --git a/doc/native_protocol_v4.spec b/doc/native_protocol_v4.spec index 5e1e01d2dde8..5670241655f8 100644 --- a/doc/native_protocol_v4.spec +++ b/doc/native_protocol_v4.spec @@ -275,6 +275,9 @@ Table of Contents mode. This mode will make all Thrift and Compact Tables to be exposed as if they were CQL Tables. This is optional; if not specified, the option will not be used. 
+ - "THROW_ON_OVERLOAD": In case of server overloaded with too many requests, by default the server puts + back pressure on the client connection. Instead, the server can send an OverloadedException error message back to + the client if this option is set to true. 4.1.2. AUTH_RESPONSE @@ -1185,3 +1188,4 @@ Table of Contents * The returned in the v4 protocol is not compatible with the v3 protocol. In other words, a returned by a node using protocol v4 should not be used to query a node using protocol v3 (and vice-versa). + * Added THROW_ON_OVERLOAD startup option (Section 4.1.1). diff --git a/doc/source/bugs.rst b/doc/source/bugs.rst index bd58a8f171d3..32d676f9d3f4 100644 --- a/doc/source/bugs.rst +++ b/doc/source/bugs.rst @@ -18,7 +18,7 @@ Reporting Bugs ============== If you encounter a problem with Cassandra, the first places to ask for help are the :ref:`user mailing list -` and the ``#cassandra`` :ref:`IRC channel `. +` and the ``cassandra`` :ref:`Slack room `. If, after having asked for help, you suspect that you have found a bug in Cassandra, you should report it by opening a ticket through the `Apache Cassandra JIRA `__. Please provide as much diff --git a/doc/source/contactus.rst b/doc/source/contactus.rst index 8d0f5dd04663..3ed9004ddcfc 100644 --- a/doc/source/contactus.rst +++ b/doc/source/contactus.rst @@ -17,7 +17,7 @@ Contact us ========== -You can get in touch with the Cassandra community either via the mailing lists or the freenode IRC channels. +You can get in touch with the Cassandra community either via the mailing lists or :ref:`Slack rooms `. .. _mailing-lists: @@ -39,15 +39,12 @@ Subscribe by sending an email to the email address in the Subscribe links above. email to confirm your subscription. Make sure to keep the welcome email as it contains instructions on how to unsubscribe. -.. _irc-channels: +.. 
_slack: -IRC ---- +Slack +----- +To chat with developers or users in real-time, join our rooms on `ASF Slack `__: -To chat with developers or users in real-time, join our channels on `IRC freenode `__. The -following channels are available: - -- ``#cassandra`` - for user questions and general discussions. -- ``#cassandra-dev`` - strictly for questions or discussions related to Cassandra development. -- ``#cassandra-builds`` - results of automated test builds. +- ``cassandra`` - for user questions and general discussions. +- ``cassandra-dev`` - strictly for questions or discussions related to Cassandra development. diff --git a/doc/source/cql/security.rst b/doc/source/cql/security.rst index 4abeb2d1a603..429a1ef0d67d 100644 --- a/doc/source/cql/security.rst +++ b/doc/source/cql/security.rst @@ -148,6 +148,14 @@ status may ``DROP`` another ``SUPERUSER`` role. Attempting to drop a role which does not exist results in an invalid query condition unless the ``IF EXISTS`` option is used. If the option is used and the role does not exist the statement is a no-op. +.. note:: DROP ROLE intentionally does not terminate any open user sessions. Currently connected sessions will remain + connected and will retain the ability to perform any database actions which do not require :ref:`authorization`. + However, if authorization is enabled, :ref:`permissions` of the dropped role are also revoked, + subject to the :ref:`caching options` configured in :ref:`cassandra.yaml`. + Should a dropped role be subsequently recreated and have new :ref:`permissions` or + :ref:`roles` granted to it, any client sessions still connected will acquire the newly granted + permissions and roles. + .. _grant-role-statement: GRANT ROLE diff --git a/doc/source/development/patches.rst b/doc/source/development/patches.rst index e8e50f6a42ef..f3a2cca0f1cf 100644 --- a/doc/source/development/patches.rst +++ b/doc/source/development/patches.rst @@ -33,12 +33,12 @@ As a general rule of thumb: .. 
hint:: - Not sure what to work? Just pick an issue tagged with the `low hanging fruit label `_ in JIRA, which we use to flag issues that could turn out to be good starter tasks for beginners. + Not sure what to work? Just pick an issue marked as `Low Hanging Fruit `_ Complexity in JIRA, which we use to flag issues that could turn out to be good starter tasks for beginners. Before You Start Coding ======================= -Although contributions are highly appreciated, we do not guarantee that each contribution will become a part of Cassandra. Therefor it's generally a good idea to first get some feedback on the things you plan to work on, especially about any new features or major changes to the code base. You can reach out to other developers on the mailing list or IRC channel listed on our `community page `_. +Although contributions are highly appreciated, we do not guarantee that each contribution will become a part of Cassandra. Therefore it's generally a good idea to first get some feedback on the things you plan to work on, especially about any new features or major changes to the code base. You can reach out to other developers on the mailing list or :ref:`Slack `. You should also * Avoid redundant work by searching for already reported issues in `JIRA `_ @@ -108,7 +108,14 @@ So you've finished coding and the great moment arrives: it's time to submit your a. Attach a patch to JIRA with a single squashed commit in it (per branch), or b. Squash the commits in-place in your branches into one - 6. Include a CHANGES.txt entry (put it at the top of the list), and format the commit message appropriately in your patch ending with the following statement on the last line: ``patch by X; reviewed by Y for CASSANDRA-ZZZZZ`` + 6. Include a CHANGES.txt entry (put it at the top of the list), and format the commit message appropriately in your patch as below. + + :: + + + + patch by ; reviewed by for CASSANDRA-##### + 7. 
When you're happy with the result, create a patch: :: diff --git a/doc/source/development/release_process.rst b/doc/source/development/release_process.rst index 23bd7abafd0a..0ab6dff1a40c 100644 --- a/doc/source/development/release_process.rst +++ b/doc/source/development/release_process.rst @@ -108,6 +108,15 @@ The next step is to copy and commit these binaries to staging svnpubsub:: svn add cassandra-dist-dev/ svn ci cassandra-dist-dev/ +After committing the binaries to staging, increment the version number in Cassandra on the `cassandra-` + + cd ~/git/cassandra/ + git checkout cassandra- + edit build.xml # update ` ` + edit debian/changelog # add entry for new version + edit CHANGES.txt # add entry for new version + git commit -m "Update version to " build.xml debian/changelog CHANGES.txt + git push Call for a Vote =============== @@ -241,12 +250,11 @@ Fill out the following email template and send to both user and dev mailing list [2]: (NEWS.txt) https://git1-us-west.apache.org/repos/asf?p=cassandra.git;a=blob_plain;f=NEWS.txt;hb= [3]: https://issues.apache.org/jira/browse/CASSANDRA -Update IRC #cassandra topic +Update Slack Cassandra topic --------------------------- -Update #cassandra topic on irc:: - /msg chanserv op #cassandra - /topic #cassandra "cassandra.apache.org | Latest: 3.11.2 (https://goo.gl/M34ZbG) | Stable: 3.0.16 (https://goo.gl/B4Zumg) | Oldstable: 2.2.12 (https://goo.gl/Uf3GVw) | ask, don't ask to ask" +Update topic in ``cassandra`` :ref:`Slack room ` + /topic cassandra.apache.org | Latest releases: 3.11.4, 3.0.18, 2.2.14, 2.1.21 | ask, don't ask to ask Tweet from @Cassandra --------------------- diff --git a/doc/source/operating/compression.rst b/doc/source/operating/compression.rst index 42a057b242d6..b4308b31a3eb 100644 --- a/doc/source/operating/compression.rst +++ b/doc/source/operating/compression.rst @@ -38,7 +38,9 @@ default, three options are relevant: - ``chunk_length_in_kb`` specifies the number of kilobytes of data per compression 
chunk. The default is 64KB. - ``crc_check_chance`` determines how likely Cassandra is to verify the checksum on each compression chunk during reads. The default is 1.0. -- ``compression_level`` is only applicable for ``ZstdCompressor`` and accepts values between ``-131072`` and ``2``. +- ``compression_level`` is only applicable for ``ZstdCompressor`` and accepts values between ``-131072`` and ``22``. + The lower the level, the faster the speed (at the cost of compression). Values from 20 to 22 are called + "ultra levels" and should be used with caution, as they require more memory. The default is 3. Users can set compression using the following syntax: diff --git a/doc/source/operating/security.rst b/doc/source/operating/security.rst index e229c7fa3e68..c2d8b79b0798 100644 --- a/doc/source/operating/security.rst +++ b/doc/source/operating/security.rst @@ -182,6 +182,8 @@ See also: :ref:`setting-credentials-for-internal-authentication`, :ref:`CREATE R :ref:`ALTER ROLE `, :ref:`ALTER KEYSPACE ` and :ref:`GRANT PERMISSION `, +.. _authorization: + Authorization ^^^^^^^^^^^^^ @@ -233,6 +235,8 @@ The following assumes that authentication has already been enabled via the proce See also: :ref:`GRANT PERMISSION `, `GRANT ALL ` and :ref:`REVOKE PERMISSION ` +.. 
_auth-caching: + Caching ^^^^^^^ diff --git a/lib/chronicle-core-1.16.3-SNAPSHOT.jar b/lib/chronicle-core-1.16.4.jar similarity index 63% rename from lib/chronicle-core-1.16.3-SNAPSHOT.jar rename to lib/chronicle-core-1.16.4.jar index eae29e4952ee..0275a72c3018 100644 Binary files a/lib/chronicle-core-1.16.3-SNAPSHOT.jar and b/lib/chronicle-core-1.16.4.jar differ diff --git a/lib/licenses/netty-4.1.28.txt b/lib/licenses/netty-4.1.37.txt similarity index 100% rename from lib/licenses/netty-4.1.28.txt rename to lib/licenses/netty-4.1.37.txt diff --git a/lib/licenses/netty-tcnative-2.0.25.txt b/lib/licenses/netty-tcnative-2.0.25.txt new file mode 100644 index 000000000000..261eeb9e9f8b --- /dev/null +++ b/lib/licenses/netty-tcnative-2.0.25.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/lib/netty-all-4.1.28.Final.jar b/lib/netty-all-4.1.37.Final.jar similarity index 56% rename from lib/netty-all-4.1.28.Final.jar rename to lib/netty-all-4.1.37.Final.jar index 058662ecc9f4..93cff04768e6 100644 Binary files a/lib/netty-all-4.1.28.Final.jar and b/lib/netty-all-4.1.37.Final.jar differ diff --git a/lib/netty-tcnative-boringssl-static-2.0.25.Final.jar b/lib/netty-tcnative-boringssl-static-2.0.25.Final.jar new file mode 100644 index 000000000000..954627fb73f6 Binary files /dev/null and b/lib/netty-tcnative-boringssl-static-2.0.25.Final.jar differ diff --git a/redhat/cassandra.spec b/redhat/cassandra.spec index eaf7565922f5..ca5d38e3e7d5 100644 --- a/redhat/cassandra.spec +++ b/redhat/cassandra.spec @@ -8,7 +8,9 @@ %global username cassandra -%define relname apache-cassandra-%{version} +# input of ~alphaN, ~betaN, ~rcN package versions need to retain upstream '-alphaN, etc' version for sources +%define upstream_version %(echo %{version} | sed -r 's/~/-/g') +%define relname apache-cassandra-%{upstream_version} Name: cassandra Version: %{version} @@ -74,14 +76,14 @@ patch -p1 < debian/patches/cassandra_logdir_fix.diff sed -i 's/^# hints_directory:/hints_directory:/' conf/cassandra.yaml # remove batch, powershell, and other files not being installed -rm conf/*.ps1 -rm bin/*.bat -rm bin/*.orig -rm bin/*.ps1 -rm bin/cassandra.in.sh -rm lib/sigar-bin/*winnt* # strip 
segfaults on dll.. -rm tools/bin/*.bat -rm tools/bin/cassandra.in.sh +rm -f conf/*.ps1 +rm -f bin/*.bat +rm -f bin/*.orig +rm -f bin/*.ps1 +rm -f bin/cassandra.in.sh +rm -f lib/sigar-bin/*winnt* # strip segfaults on dll.. +rm -f tools/bin/*.bat +rm -f tools/bin/cassandra.in.sh # copy default configs cp -pr conf/* %{buildroot}/%{_sysconfdir}/%{username}/default.conf/ @@ -118,10 +120,12 @@ exit 0 %files %defattr(0644,root,root,0755) %doc CHANGES.txt LICENSE.txt README.asc NEWS.txt NOTICE.txt CASSANDRA-14092.txt +%attr(755,root,root) %{_bindir}/auditlogviewer %attr(755,root,root) %{_bindir}/cassandra-stress %attr(755,root,root) %{_bindir}/cqlsh %attr(755,root,root) %{_bindir}/cqlsh.py %attr(755,root,root) %{_bindir}/debug-cql +%attr(755,root,root) %{_bindir}/fqltool %attr(755,root,root) %{_bindir}/nodetool %attr(755,root,root) %{_bindir}/sstableloader %attr(755,root,root) %{_bindir}/sstablescrub diff --git a/src/java/org/apache/cassandra/auth/CassandraAuthorizer.java b/src/java/org/apache/cassandra/auth/CassandraAuthorizer.java index 238b5b506e79..37ad60a9e4cd 100644 --- a/src/java/org/apache/cassandra/auth/CassandraAuthorizer.java +++ b/src/java/org/apache/cassandra/auth/CassandraAuthorizer.java @@ -53,7 +53,7 @@ public class CassandraAuthorizer implements IAuthorizer private static final String RESOURCE = "resource"; private static final String PERMISSIONS = "permissions"; - SelectStatement authorizeRoleStatement; + private SelectStatement authorizeRoleStatement; public CassandraAuthorizer() { @@ -63,16 +63,24 @@ public CassandraAuthorizer() // or indirectly via roles granted to the user. 
public Set authorize(AuthenticatedUser user, IResource resource) { - if (user.isSuper()) - return resource.applicablePermissions(); + try + { + if (user.isSuper()) + return resource.applicablePermissions(); - Set permissions = EnumSet.noneOf(Permission.class); + Set permissions = EnumSet.noneOf(Permission.class); - // Even though we only care about the RoleResource here, we use getRoleDetails as - // it saves a Set creation in RolesCache - for (Role role: user.getRoleDetails()) - addPermissionsForRole(permissions, resource, role.resource); - return permissions; + // Even though we only care about the RoleResource here, we use getRoleDetails as + // it saves a Set creation in RolesCache + for (Role role: user.getRoleDetails()) + addPermissionsForRole(permissions, resource, role.resource); + return permissions; + } + catch (RequestExecutionException | RequestValidationException e) + { + logger.debug("Failed to authorize {} for {}", user, resource); + throw new UnauthorizedException("Unable to perform authorization of permissions: " + e.getMessage(), e); + } } public void grant(AuthenticatedUser performer, Set permissions, IResource resource, RoleResource grantee) diff --git a/src/java/org/apache/cassandra/auth/CassandraNetworkAuthorizer.java b/src/java/org/apache/cassandra/auth/CassandraNetworkAuthorizer.java index 34a01402683b..6fdcd6959c9d 100644 --- a/src/java/org/apache/cassandra/auth/CassandraNetworkAuthorizer.java +++ b/src/java/org/apache/cassandra/auth/CassandraNetworkAuthorizer.java @@ -78,7 +78,7 @@ private Set getAuthorizedDcs(String name) public DCPermissions authorize(RoleResource role) { - if (!DatabaseDescriptor.getRoleManager().canLogin(role)) + if (!Roles.canLogin(role)) { return DCPermissions.none(); } diff --git a/src/java/org/apache/cassandra/auth/CassandraRoleManager.java b/src/java/org/apache/cassandra/auth/CassandraRoleManager.java index ebb7d5ff99d4..f4c942859245 100644 --- a/src/java/org/apache/cassandra/auth/CassandraRoleManager.java +++ 
b/src/java/org/apache/cassandra/auth/CassandraRoleManager.java @@ -270,12 +270,28 @@ public Set getAllRoles() throws RequestValidationException, Reques public boolean isSuper(RoleResource role) { - return getRole(role.getRoleName()).isSuper; + try + { + return getRole(role.getRoleName()).isSuper; + } + catch (RequestExecutionException e) + { + logger.debug("Failed to authorize {} for super-user permission", role.getRoleName()); + throw new UnauthorizedException("Unable to perform authorization of super-user permission: " + e.getMessage(), e); + } } public boolean canLogin(RoleResource role) { - return getRole(role.getRoleName()).canLogin; + try + { + return getRole(role.getRoleName()).canLogin; + } + catch (RequestExecutionException e) + { + logger.debug("Failed to authorize {} for login permission", role.getRoleName()); + throw new UnauthorizedException("Unable to perform authorization of login permission: " + e.getMessage(), e); + } } public Map getCustomOptions(RoleResource role) diff --git a/src/java/org/apache/cassandra/auth/IAuthenticator.java b/src/java/org/apache/cassandra/auth/IAuthenticator.java index 212e77495052..80ea719237b9 100644 --- a/src/java/org/apache/cassandra/auth/IAuthenticator.java +++ b/src/java/org/apache/cassandra/auth/IAuthenticator.java @@ -105,7 +105,7 @@ default SaslNegotiator newSaslNegotiator(InetAddress clientAddress, X509Certific public interface SaslNegotiator { /** - * Evaluates the client response data and generates a byte[] reply which may be a further challenge or purely + * Evaluates the client response data and generates a byte[] response which may be a further challenge or purely * informational in the case that the negotiation is completed on this round. 
* * This method is called each time a {@link org.apache.cassandra.transport.messages.AuthResponse} is received diff --git a/src/java/org/apache/cassandra/auth/PasswordAuthenticator.java b/src/java/org/apache/cassandra/auth/PasswordAuthenticator.java index 27a68a02c043..9da99a9a6083 100644 --- a/src/java/org/apache/cassandra/auth/PasswordAuthenticator.java +++ b/src/java/org/apache/cassandra/auth/PasswordAuthenticator.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestExecutionException; import org.apache.cassandra.schema.SchemaConstants; import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.QueryProcessor; @@ -64,7 +65,7 @@ public class PasswordAuthenticator implements IAuthenticator public static final String USERNAME_KEY = "username"; public static final String PASSWORD_KEY = "password"; - private static final byte NUL = 0; + static final byte NUL = 0; private SelectStatement authenticateStatement; private CredentialsCache cache; @@ -100,23 +101,30 @@ private AuthenticatedUser authenticate(String username, String password) throws private String queryHashedPassword(String username) { - ResultMessage.Rows rows = - authenticateStatement.execute(QueryState.forInternalCalls(), - QueryOptions.forInternalCalls(consistencyForRole(username), - Lists.newArrayList(ByteBufferUtil.bytes(username))), - System.nanoTime()); - - // If either a non-existent role name was supplied, or no credentials - // were found for that role we don't want to cache the result so we throw - // a specific, but unchecked, exception to keep LoadingCache happy. 
- if (rows.result.isEmpty()) - throw new AuthenticationException(String.format("Provided username %s and/or password are incorrect", username)); - - UntypedResultSet result = UntypedResultSet.create(rows.result); - if (!result.one().has(SALTED_HASH)) - throw new AuthenticationException(String.format("Provided username %s and/or password are incorrect", username)); - - return result.one().getString(SALTED_HASH); + try + { + ResultMessage.Rows rows = + authenticateStatement.execute(QueryState.forInternalCalls(), + QueryOptions.forInternalCalls(consistencyForRole(username), + Lists.newArrayList(ByteBufferUtil.bytes(username))), + System.nanoTime()); + + // If either a non-existent role name was supplied, or no credentials + // were found for that role we don't want to cache the result so we throw + // an exception. + if (rows.result.isEmpty()) + throw new AuthenticationException(String.format("Provided username %s and/or password are incorrect", username)); + + UntypedResultSet result = UntypedResultSet.create(rows.result); + if (!result.one().has(SALTED_HASH)) + throw new AuthenticationException(String.format("Provided username %s and/or password are incorrect", username)); + + return result.one().getString(SALTED_HASH); + } + catch (RequestExecutionException e) + { + throw new AuthenticationException("Unable to perform authentication: " + e.getMessage(), e); + } } public Set protectedResources() @@ -206,7 +214,7 @@ private void decodeCredentials(byte[] bytes) throws AuthenticationException byte[] user = null; byte[] pass = null; int end = bytes.length; - for (int i = bytes.length - 1 ; i >= 0; i--) + for (int i = bytes.length - 1; i >= 0; i--) { if (bytes[i] == NUL) { @@ -214,13 +222,16 @@ private void decodeCredentials(byte[] bytes) throws AuthenticationException pass = Arrays.copyOfRange(bytes, i + 1, end); else if (user == null) user = Arrays.copyOfRange(bytes, i + 1, end); + else + throw new AuthenticationException("Credential format error: username or password 
is empty or contains NUL(\\0) character"); + end = i; } } - if (pass == null) + if (pass == null || pass.length == 0) throw new AuthenticationException("Password must not be null"); - if (user == null) + if (user == null || user.length == 0) throw new AuthenticationException("Authentication ID must not be null"); username = new String(user, StandardCharsets.UTF_8); diff --git a/src/java/org/apache/cassandra/auth/Roles.java b/src/java/org/apache/cassandra/auth/Roles.java index 22eb0d31e82a..527451e703c0 100644 --- a/src/java/org/apache/cassandra/auth/Roles.java +++ b/src/java/org/apache/cassandra/auth/Roles.java @@ -25,10 +25,17 @@ import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestExecutionException; +import org.apache.cassandra.exceptions.UnauthorizedException; public class Roles { + private static final Logger logger = LoggerFactory.getLogger(Roles.class); + private static final Role NO_ROLE = new Role("", false, false, Collections.emptyMap(), Collections.emptySet()); private static RolesCache cache; @@ -91,11 +98,19 @@ public static Set getRoleDetails(RoleResource primaryRole) */ public static boolean hasSuperuserStatus(RoleResource role) { - for (Role r : getRoleDetails(role)) - if (r.isSuper) - return true; - - return false; + try + { + for (Role r : getRoleDetails(role)) + if (r.isSuper) + return true; + + return false; + } + catch (RequestExecutionException e) + { + logger.debug("Failed to authorize {} for super-user permission", role.getRoleName()); + throw new UnauthorizedException("Unable to perform authorization of super-user permission: " + e.getMessage(), e); + } } /** @@ -106,11 +121,19 @@ public static boolean hasSuperuserStatus(RoleResource role) */ public static boolean canLogin(final RoleResource role) { - for (Role r : getRoleDetails(role)) - if (r.resource.equals(role)) - return 
r.canLogin; - - return false; + try + { + for (Role r : getRoleDetails(role)) + if (r.resource.equals(role)) + return r.canLogin; + + return false; + } + catch (RequestExecutionException e) + { + logger.debug("Failed to authorize {} for login permission", role.getRoleName()); + throw new UnauthorizedException("Unable to perform authorization of login permission: " + e.getMessage(), e); + } } /** diff --git a/src/java/org/apache/cassandra/auth/jmx/AuthorizationProxy.java b/src/java/org/apache/cassandra/auth/jmx/AuthorizationProxy.java index ef00027604b9..b213c43d66c2 100644 --- a/src/java/org/apache/cassandra/auth/jmx/AuthorizationProxy.java +++ b/src/java/org/apache/cassandra/auth/jmx/AuthorizationProxy.java @@ -33,7 +33,6 @@ import javax.security.auth.Subject; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -491,10 +490,5 @@ protected JMXPermissionsCache() AuthorizationProxy::loadPermissions, () -> true); } - - public Set get(RoleResource roleResource) - { - return super.get(roleResource); - } } } diff --git a/src/java/org/apache/cassandra/batchlog/BatchRemoveVerbHandler.java b/src/java/org/apache/cassandra/batchlog/BatchRemoveVerbHandler.java index 3c3fcec49096..3443cab78393 100644 --- a/src/java/org/apache/cassandra/batchlog/BatchRemoveVerbHandler.java +++ b/src/java/org/apache/cassandra/batchlog/BatchRemoveVerbHandler.java @@ -20,11 +20,13 @@ import java.util.UUID; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; public final class BatchRemoveVerbHandler implements IVerbHandler { - public void doVerb(MessageIn message, int id) + public static final BatchRemoveVerbHandler instance = new BatchRemoveVerbHandler(); + + public void doVerb(Message message) { BatchlogManager.remove(message.payload); } diff --git 
a/src/java/org/apache/cassandra/batchlog/BatchStoreVerbHandler.java b/src/java/org/apache/cassandra/batchlog/BatchStoreVerbHandler.java index 4bc878cbf592..77335cb44389 100644 --- a/src/java/org/apache/cassandra/batchlog/BatchStoreVerbHandler.java +++ b/src/java/org/apache/cassandra/batchlog/BatchStoreVerbHandler.java @@ -17,16 +17,17 @@ */ package org.apache.cassandra.batchlog; -import org.apache.cassandra.db.WriteResponse; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; public final class BatchStoreVerbHandler implements IVerbHandler { - public void doVerb(MessageIn message, int id) + public static final BatchStoreVerbHandler instance = new BatchStoreVerbHandler(); + + public void doVerb(Message message) { BatchlogManager.store(message.payload); - MessagingService.instance().sendReply(WriteResponse.createMessage(), id, message.from); + MessagingService.instance().send(message.emptyResponse(), message.from()); } } diff --git a/src/java/org/apache/cassandra/batchlog/BatchlogManager.java b/src/java/org/apache/cassandra/batchlog/BatchlogManager.java index b2b851df89d0..f140332f509f 100644 --- a/src/java/org/apache/cassandra/batchlog/BatchlogManager.java +++ b/src/java/org/apache/cassandra/batchlog/BatchlogManager.java @@ -32,6 +32,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Collections2; @@ -65,20 +66,23 @@ import org.apache.cassandra.locator.ReplicaLayout; import org.apache.cassandra.locator.ReplicaPlan; import org.apache.cassandra.locator.Replicas; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; +import 
org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.SchemaConstants; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.service.WriteResponseHandler; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.MBeanWrapper; import org.apache.cassandra.utils.UUIDGen; import static com.google.common.collect.Iterables.transform; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.cassandra.cql3.QueryProcessor.executeInternal; import static org.apache.cassandra.cql3.QueryProcessor.executeInternalWithPaging; +import static org.apache.cassandra.net.Verb.MUTATION_REQ; public class BatchlogManager implements BatchlogManagerMBean { @@ -88,7 +92,7 @@ public class BatchlogManager implements BatchlogManagerMBean private static final Logger logger = LoggerFactory.getLogger(BatchlogManager.class); public static final BatchlogManager instance = new BatchlogManager(); - public static final long BATCHLOG_REPLAY_TIMEOUT = Long.getLong("cassandra.batchlog.replay_timeout_in_ms", DatabaseDescriptor.getWriteRpcTimeout() * 2); + public static final long BATCHLOG_REPLAY_TIMEOUT = Long.getLong("cassandra.batchlog.replay_timeout_in_ms", DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS) * 2); private volatile long totalBatchesReplayed = 0; // no concurrency protection necessary as only written by replay thread. 
private volatile UUID lastReplayedUuid = UUIDGen.minTimeUUID(0); @@ -112,13 +116,12 @@ public void start() batchlogTasks.scheduleWithFixedDelay(this::replayFailedBatches, StorageService.RING_DELAY, REPLAY_INTERVAL, - TimeUnit.MILLISECONDS); + MILLISECONDS); } - public void shutdown() throws InterruptedException + public void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - batchlogTasks.shutdown(); - batchlogTasks.awaitTermination(60, TimeUnit.SECONDS); + ExecutorUtils.shutdownAndWait(timeout, unit, batchlogTasks); } public static void remove(UUID id) @@ -356,7 +359,7 @@ public int replay(RateLimiter rateLimiter, Set hintedNodes) return 0; int gcgs = gcgs(mutations); - if (TimeUnit.MILLISECONDS.toSeconds(writtenAt) + gcgs <= FBUtilities.nowInSeconds()) + if (MILLISECONDS.toSeconds(writtenAt) + gcgs <= FBUtilities.nowInSeconds()) return 0; replayHandlers = sendReplays(mutations, writtenAt, hintedNodes); @@ -419,7 +422,7 @@ private void writeHintsForUndeliveredEndpoints(int startFrom, Set sendSingleReplayMutation(fin ReplicaPlan.ForTokenWrite replicaPlan = new ReplicaPlan.ForTokenWrite(keyspace, ConsistencyLevel.ONE, liveRemoteOnly.pending(), liveRemoteOnly.all(), liveRemoteOnly.all(), liveRemoteOnly.all()); ReplayWriteResponseHandler handler = new ReplayWriteResponseHandler<>(replicaPlan, System.nanoTime()); - MessageOut message = mutation.createMessage(); + Message message = Message.outWithFlag(MUTATION_REQ, mutation, MessageFlag.CALL_BACK_ON_FAILURE); for (Replica replica : liveRemoteOnly.all()) - MessagingService.instance().sendWriteRR(message, replica, handler, false); + MessagingService.instance().sendWriteWithCallback(message, replica, handler, false); return handler; } @@ -506,7 +509,7 @@ private static int gcgs(Collection mutations) /** * A wrapper of WriteResponseHandler that stores the addresses of the endpoints from - * which we did not receive a successful reply. 
+ * which we did not receive a successful response. */ private static class ReplayWriteResponseHandler extends WriteResponseHandler { @@ -525,11 +528,11 @@ protected int blockFor() } @Override - public void response(MessageIn m) + public void onResponse(Message m) { - boolean removed = undelivered.remove(m == null ? FBUtilities.getBroadcastAddressAndPort() : m.from); + boolean removed = undelivered.remove(m == null ? FBUtilities.getBroadcastAddressAndPort() : m.from()); assert removed; - super.response(m); + super.onResponse(m); } } } diff --git a/src/java/org/apache/cassandra/concurrent/ImmediateExecutor.java b/src/java/org/apache/cassandra/concurrent/ImmediateExecutor.java new file mode 100644 index 000000000000..1a00e4f3beda --- /dev/null +++ b/src/java/org/apache/cassandra/concurrent/ImmediateExecutor.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.concurrent; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.TimeUnit; + +public class ImmediateExecutor extends AbstractExecutorService implements LocalAwareExecutorService +{ + public static final ImmediateExecutor INSTANCE = new ImmediateExecutor(); + + private ImmediateExecutor() {} + + public void execute(Runnable command, ExecutorLocals locals) + { + command.run(); + } + + public void maybeExecuteImmediately(Runnable command) + { + command.run(); + } + + public void execute(Runnable command) + { + command.run(); + } + + public int getActiveTaskCount() { return 0; } + public long getCompletedTaskCount() { return 0; } + public int getPendingTaskCount() { return 0; } + public int getMaximumPoolSize() { return 0; } + public void shutdown() { } + public List shutdownNow() { return Collections.emptyList(); } + public boolean isShutdown() { return false; } + public boolean isTerminated() { return false; } + public boolean awaitTermination(long timeout, TimeUnit unit) { return true; } +} diff --git a/src/java/org/apache/cassandra/concurrent/InfiniteLoopExecutor.java b/src/java/org/apache/cassandra/concurrent/InfiniteLoopExecutor.java index 199803f04489..b54fa3fca51f 100644 --- a/src/java/org/apache/cassandra/concurrent/InfiniteLoopExecutor.java +++ b/src/java/org/apache/cassandra/concurrent/InfiniteLoopExecutor.java @@ -70,7 +70,7 @@ public InfiniteLoopExecutor start() return this; } - public void shutdown() + public void shutdownNow() { isShutdown = true; thread.interrupt(); diff --git a/src/java/org/apache/cassandra/concurrent/NamedThreadFactory.java b/src/java/org/apache/cassandra/concurrent/NamedThreadFactory.java index 33f1312b9f8c..7cc73bd1c3c4 100644 --- a/src/java/org/apache/cassandra/concurrent/NamedThreadFactory.java +++ b/src/java/org/apache/cassandra/concurrent/NamedThreadFactory.java @@ -24,6 +24,7 @@ import 
io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; +import org.apache.cassandra.utils.memory.BufferPool; /** * This class is an implementation of the ThreadFactory interface. This @@ -35,6 +36,7 @@ public class NamedThreadFactory implements ThreadFactory { private static volatile String globalPrefix; public static void setGlobalPrefix(String prefix) { globalPrefix = prefix; } + public static String globalPrefix() { return globalPrefix == null ? "" : globalPrefix; } public final String id; private final int priority; diff --git a/src/java/org/apache/cassandra/concurrent/ScheduledExecutors.java b/src/java/org/apache/cassandra/concurrent/ScheduledExecutors.java index 5e3e5cf3964c..c549c4d587c8 100644 --- a/src/java/org/apache/cassandra/concurrent/ScheduledExecutors.java +++ b/src/java/org/apache/cassandra/concurrent/ScheduledExecutors.java @@ -17,10 +17,17 @@ */ package org.apache.cassandra.concurrent; +import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import org.apache.cassandra.utils.ExecutorUtils; + +import org.apache.cassandra.utils.ExecutorUtils; /** * Centralized location for shared executors @@ -48,12 +55,8 @@ public class ScheduledExecutors public static final DebuggableScheduledThreadPoolExecutor optionalTasks = new DebuggableScheduledThreadPoolExecutor("OptionalTasks"); @VisibleForTesting - public static void shutdownAndWait() throws InterruptedException + public static void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - ExecutorService[] executors = new ExecutorService[] { scheduledFastTasks, scheduledTasks, nonPeriodicTasks, optionalTasks }; - for (ExecutorService executor : executors) - executor.shutdownNow(); - for (ExecutorService executor : executors) - 
executor.awaitTermination(60, TimeUnit.SECONDS); + ExecutorUtils.shutdownNowAndWait(timeout, unit, scheduledFastTasks, scheduledTasks, nonPeriodicTasks, optionalTasks); } } diff --git a/src/java/org/apache/cassandra/concurrent/SharedExecutorPool.java b/src/java/org/apache/cassandra/concurrent/SharedExecutorPool.java index 62bede9add1a..3388ea495807 100644 --- a/src/java/org/apache/cassandra/concurrent/SharedExecutorPool.java +++ b/src/java/org/apache/cassandra/concurrent/SharedExecutorPool.java @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.LockSupport; @@ -107,14 +108,14 @@ void maybeStartSpinningWorker() schedule(Work.SPINNING); } - public LocalAwareExecutorService newExecutor(int maxConcurrency, int maxQueuedTasks, String jmxPath, String name) + public synchronized LocalAwareExecutorService newExecutor(int maxConcurrency, int maxQueuedTasks, String jmxPath, String name) { SEPExecutor executor = new SEPExecutor(this, maxConcurrency, maxQueuedTasks, jmxPath, name); executors.add(executor); return executor; } - public void shutdown() throws InterruptedException + public synchronized void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { shuttingDown = true; for (SEPExecutor executor : executors) @@ -122,9 +123,13 @@ public void shutdown() throws InterruptedException terminateWorkers(); - long until = System.nanoTime() + TimeUnit.MINUTES.toNanos(1L); + long until = System.nanoTime() + unit.toNanos(timeout); for (SEPExecutor executor : executors) + { executor.shutdown.await(until - System.nanoTime(), TimeUnit.NANOSECONDS); + if (!executor.isTerminated()) + throw new TimeoutException(executor.name + " not terminated"); + } } void 
terminateWorkers() diff --git a/src/java/org/apache/cassandra/concurrent/Stage.java b/src/java/org/apache/cassandra/concurrent/Stage.java index ccb156501e40..ed13eebe2b47 100644 --- a/src/java/org/apache/cassandra/concurrent/Stage.java +++ b/src/java/org/apache/cassandra/concurrent/Stage.java @@ -17,11 +17,6 @@ */ package org.apache.cassandra.concurrent; -import java.util.Arrays; - -import com.google.common.base.Predicate; -import com.google.common.collect.Iterables; - public enum Stage { READ, @@ -35,18 +30,7 @@ public enum Stage MISC, TRACING, INTERNAL_RESPONSE, - READ_REPAIR; - - public static Iterable jmxEnabledStages() - { - return Iterables.filter(Arrays.asList(values()), new Predicate() - { - public boolean apply(Stage stage) - { - return stage != TRACING; - } - }); - } + IMMEDIATE; public String getJmxType() { @@ -58,13 +42,13 @@ public String getJmxType() case MISC: case TRACING: case INTERNAL_RESPONSE: + case IMMEDIATE: return "internal"; case MUTATION: case COUNTER_MUTATION: case VIEW_MUTATION: case READ: case REQUEST_RESPONSE: - case READ_REPAIR: return "request"; default: throw new AssertionError("Unknown stage " + this); diff --git a/src/java/org/apache/cassandra/concurrent/StageManager.java b/src/java/org/apache/cassandra/concurrent/StageManager.java index 608a00595216..b0d34aee09bc 100644 --- a/src/java/org/apache/cassandra/concurrent/StageManager.java +++ b/src/java/org/apache/cassandra/concurrent/StageManager.java @@ -17,7 +17,9 @@ */ package org.apache.cassandra.concurrent; +import java.util.Collections; import java.util.EnumMap; +import java.util.List; import java.util.concurrent.*; import com.google.common.annotations.VisibleForTesting; @@ -25,9 +27,12 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.FBUtilities; import static org.apache.cassandra.config.DatabaseDescriptor.*; +import 
static org.apache.cassandra.utils.ExecutorUtils.*; /** @@ -56,24 +61,18 @@ public class StageManager stages.put(Stage.ANTI_ENTROPY, new JMXEnabledThreadPoolExecutor(Stage.ANTI_ENTROPY)); stages.put(Stage.MIGRATION, new JMXEnabledThreadPoolExecutor(Stage.MIGRATION)); stages.put(Stage.MISC, new JMXEnabledThreadPoolExecutor(Stage.MISC)); - stages.put(Stage.READ_REPAIR, multiThreadedStage(Stage.READ_REPAIR, FBUtilities.getAvailableProcessors())); stages.put(Stage.TRACING, tracingExecutor()); + stages.put(Stage.IMMEDIATE, ImmediateExecutor.INSTANCE); } private static LocalAwareExecutorService tracingExecutor() { - RejectedExecutionHandler reh = new RejectedExecutionHandler() - { - public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) - { - MessagingService.instance().incrementDroppedMessages(MessagingService.Verb._TRACE); - } - }; + RejectedExecutionHandler reh = (r, executor) -> MessagingService.instance().metrics.recordSelfDroppedMessage(Verb._TRACE); return new TracingExecutor(1, 1, KEEPALIVE, TimeUnit.SECONDS, - new ArrayBlockingQueue(1000), + new ArrayBlockingQueue<>(1000), new NamedThreadFactory(Stage.TRACING.getJmxName()), reh); } @@ -114,12 +113,9 @@ public static void shutdownNow() } @VisibleForTesting - public static void shutdownAndWait() throws InterruptedException + public static void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - for (Stage stage : Stage.values()) - StageManager.stages.get(stage).shutdown(); - for (Stage stage : Stage.values()) - StageManager.stages.get(stage).awaitTermination(60, TimeUnit.SECONDS); + ExecutorUtils.shutdownNowAndWait(timeout, unit, StageManager.stages.values()); } /** @@ -155,4 +151,5 @@ public int getPendingTaskCount() return getQueue().size(); } } + } diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java index a6050bea8baa..b86b7c57e690 100644 --- a/src/java/org/apache/cassandra/config/Config.java 
+++ b/src/java/org/apache/cassandra/config/Config.java @@ -86,6 +86,10 @@ public class Config public int num_tokens = 1; /** Triggers automatic allocation of tokens if set, using the replication strategy of the referenced keyspace */ public String allocate_tokens_for_keyspace = null; + /** Triggers automatic allocation of tokens if set, based on the provided replica count for a datacenter */ + public Integer allocate_tokens_for_local_replication_factor = null; + + public long native_transport_idle_timeout_in_ms = 0L; public volatile long request_timeout_in_ms = 10000L; @@ -128,6 +132,8 @@ public class Config public volatile Integer repair_session_max_tree_depth = null; public volatile Integer repair_session_space_in_mb = null; + public volatile boolean use_offheap_merkle_trees = true; + public int storage_port = 7000; public int ssl_storage_port = 7001; public String listen_address; @@ -149,8 +155,21 @@ public class Config public boolean rpc_interface_prefer_ipv6 = false; public String broadcast_rpc_address; public boolean rpc_keepalive = true; - public int internode_send_buff_size_in_bytes = 0; - public int internode_recv_buff_size_in_bytes = 0; + + public Integer internode_max_message_size_in_bytes; + + public int internode_socket_send_buffer_size_in_bytes = 0; + public int internode_socket_receive_buffer_size_in_bytes = 0; + + // TODO: derive defaults from system memory settings? 
+ public int internode_application_send_queue_capacity_in_bytes = 1 << 22; // 4MiB + public int internode_application_send_queue_reserve_endpoint_capacity_in_bytes = 1 << 27; // 128MiB + public int internode_application_send_queue_reserve_global_capacity_in_bytes = 1 << 29; // 512MiB + + public int internode_application_receive_queue_capacity_in_bytes = 1 << 22; // 4MiB + public int internode_application_receive_queue_reserve_endpoint_capacity_in_bytes = 1 << 27; // 128MiB + public int internode_application_receive_queue_reserve_global_capacity_in_bytes = 1 << 29; // 512MiB + // Defensive settings for protecting Cassandra from true network partitions. See (CASSANDRA-14358) for details. // The amount of time to wait for internode tcp connections to establish. public int internode_tcp_connect_timeout_in_ms = 2000; @@ -170,6 +189,9 @@ public class Config public boolean native_transport_flush_in_batches_legacy = false; public volatile boolean native_transport_allow_older_protocols = true; public int native_transport_frame_block_size_in_kb = 32; + public volatile long native_transport_max_concurrent_requests_in_bytes_per_ip = -1L; + public volatile long native_transport_max_concurrent_requests_in_bytes = -1L; + /** * Max size of values in SSTables, in MegaBytes. @@ -322,7 +344,7 @@ public class Config public volatile ConsistencyLevel ideal_consistency_level = null; /* - * Strategy to use for coalescing messages in {@link OutboundMessagingPool}. + * Strategy to use for coalescing messages in {@link OutboundConnections}. * Can be fixed, movingaverage, timehorizon, disabled. Setting is case and leading/trailing * whitespace insensitive. You can also specify a subclass of * {@link org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy} by name. 
@@ -339,12 +361,6 @@ public class Config public int otc_coalescing_window_us = otc_coalescing_window_us_default; public int otc_coalescing_enough_coalesced_messages = 8; - /** - * Backlog expiration interval in milliseconds for the OutboundTcpConnection. - */ - public static final int otc_backlog_expiration_interval_ms_default = 200; - public volatile int otc_backlog_expiration_interval_ms = otc_backlog_expiration_interval_ms_default; - public int windows_timer_interval = 0; /** diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java index b3ab054ce95a..e4ea611b734f 100644 --- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java +++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java @@ -27,6 +27,7 @@ import java.nio.file.Paths; import java.util.*; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; @@ -72,6 +73,7 @@ import org.apache.commons.lang3.StringUtils; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.cassandra.io.util.FileUtils.ONE_GB; public class DatabaseDescriptor @@ -80,6 +82,7 @@ public class DatabaseDescriptor { // This static block covers most usages FBUtilities.preventIllegalAccessWarnings(); + System.setProperty("io.netty.transport.estimateSizeOnSubmit", "false"); } private static final Logger logger = LoggerFactory.getLogger(DatabaseDescriptor.class); @@ -505,6 +508,16 @@ else if (conf.native_transport_max_frame_size_in_mb >= 2048) conf.hints_directory = storagedirFor("hints"); } + if (conf.native_transport_max_concurrent_requests_in_bytes <= 0) + { + conf.native_transport_max_concurrent_requests_in_bytes = Runtime.getRuntime().maxMemory() / 10; + } + + if (conf.native_transport_max_concurrent_requests_in_bytes_per_ip <= 0) + { + conf.native_transport_max_concurrent_requests_in_bytes_per_ip = 
Runtime.getRuntime().maxMemory() / 40; + } + if (conf.cdc_raw_directory == null) { conf.cdc_raw_directory = storagedirFor("cdc_raw"); } @@ -801,6 +814,28 @@ else if (conf.max_value_size_in_mb >= 2048) if (conf.otc_coalescing_enough_coalesced_messages <= 0) throw new ConfigurationException("otc_coalescing_enough_coalesced_messages must be positive", false); + Integer maxMessageSize = conf.internode_max_message_size_in_bytes; + if (maxMessageSize != null) + { + if (maxMessageSize > conf.internode_application_receive_queue_reserve_endpoint_capacity_in_bytes) + throw new ConfigurationException("internode_max_message_size_in_bytes must not exceed internode_application_receive_queue_reserve_endpoint_capacity_in_bytes", false); + + if (maxMessageSize > conf.internode_application_receive_queue_reserve_global_capacity_in_bytes) + throw new ConfigurationException("internode_max_message_size_in_bytes must not exceed internode_application_receive_queue_reserve_global_capacity_in_bytes", false); + + if (maxMessageSize > conf.internode_application_send_queue_reserve_endpoint_capacity_in_bytes) + throw new ConfigurationException("internode_max_message_size_in_bytes must not exceed internode_application_send_queue_reserve_endpoint_capacity_in_bytes", false); + + if (maxMessageSize > conf.internode_application_send_queue_reserve_global_capacity_in_bytes) + throw new ConfigurationException("internode_max_message_size_in_bytes must not exceed internode_application_send_queue_reserve_global_capacity_in_bytes", false); + } + else + { + conf.internode_max_message_size_in_bytes = + Math.min(conf.internode_application_receive_queue_reserve_endpoint_capacity_in_bytes, + conf.internode_application_send_queue_reserve_endpoint_capacity_in_bytes); + } + validateMaxConcurrentAutoUpgradeTasksConf(conf.max_concurrent_automatic_sstable_upgrades); } @@ -1387,6 +1422,11 @@ public static String getAllocateTokensForKeyspace() return System.getProperty(Config.PROPERTY_PREFIX + "allocate_tokens_for_keyspace", 
conf.allocate_tokens_for_keyspace); } + public static Integer getAllocateTokensForLocalRf() + { + return conf.allocate_tokens_for_local_replication_factor; + } + public static Collection tokensFromString(String tokenString) { List tokens = new ArrayList(); @@ -1448,9 +1488,19 @@ public static int getSSLStoragePort() return Integer.parseInt(System.getProperty(Config.PROPERTY_PREFIX + "ssl_storage_port", Integer.toString(conf.ssl_storage_port))); } - public static long getRpcTimeout() + public static long nativeTransportIdleTimeout() + { + return conf.native_transport_idle_timeout_in_ms; + } + + public static void setNativeTransportIdleTimeout(long nativeTransportTimeout) + { + conf.native_transport_idle_timeout_in_ms = nativeTransportTimeout; + } + + public static long getRpcTimeout(TimeUnit unit) { - return conf.request_timeout_in_ms; + return unit.convert(conf.request_timeout_in_ms, MILLISECONDS); } public static void setRpcTimeout(long timeOutInMillis) @@ -1458,9 +1508,9 @@ public static void setRpcTimeout(long timeOutInMillis) conf.request_timeout_in_ms = timeOutInMillis; } - public static long getReadRpcTimeout() + public static long getReadRpcTimeout(TimeUnit unit) { - return conf.read_request_timeout_in_ms; + return unit.convert(conf.read_request_timeout_in_ms, MILLISECONDS); } public static void setReadRpcTimeout(long timeOutInMillis) @@ -1468,9 +1518,9 @@ public static void setReadRpcTimeout(long timeOutInMillis) conf.read_request_timeout_in_ms = timeOutInMillis; } - public static long getRangeRpcTimeout() + public static long getRangeRpcTimeout(TimeUnit unit) { - return conf.range_request_timeout_in_ms; + return unit.convert(conf.range_request_timeout_in_ms, MILLISECONDS); } public static void setRangeRpcTimeout(long timeOutInMillis) @@ -1478,9 +1528,9 @@ public static void setRangeRpcTimeout(long timeOutInMillis) conf.range_request_timeout_in_ms = timeOutInMillis; } - public static long getWriteRpcTimeout() + public static long getWriteRpcTimeout(TimeUnit 
unit) { - return conf.write_request_timeout_in_ms; + return unit.convert(conf.write_request_timeout_in_ms, MILLISECONDS); } public static void setWriteRpcTimeout(long timeOutInMillis) @@ -1488,9 +1538,9 @@ public static void setWriteRpcTimeout(long timeOutInMillis) conf.write_request_timeout_in_ms = timeOutInMillis; } - public static long getCounterWriteRpcTimeout() + public static long getCounterWriteRpcTimeout(TimeUnit unit) { - return conf.counter_write_request_timeout_in_ms; + return unit.convert(conf.counter_write_request_timeout_in_ms, MILLISECONDS); } public static void setCounterWriteRpcTimeout(long timeOutInMillis) @@ -1498,9 +1548,9 @@ public static void setCounterWriteRpcTimeout(long timeOutInMillis) conf.counter_write_request_timeout_in_ms = timeOutInMillis; } - public static long getCasContentionTimeout() + public static long getCasContentionTimeout(TimeUnit unit) { - return conf.cas_contention_timeout_in_ms; + return unit.convert(conf.cas_contention_timeout_in_ms, MILLISECONDS); } public static void setCasContentionTimeout(long timeOutInMillis) @@ -1508,9 +1558,9 @@ public static void setCasContentionTimeout(long timeOutInMillis) conf.cas_contention_timeout_in_ms = timeOutInMillis; } - public static long getTruncateRpcTimeout() + public static long getTruncateRpcTimeout(TimeUnit unit) { - return conf.truncate_request_timeout_in_ms; + return unit.convert(conf.truncate_request_timeout_in_ms, MILLISECONDS); } public static void setTruncateRpcTimeout(long timeOutInMillis) @@ -1523,27 +1573,32 @@ public static boolean hasCrossNodeTimeout() return conf.cross_node_timeout; } - public static long getSlowQueryTimeout() + public static void setCrossNodeTimeout(boolean crossNodeTimeout) { - return conf.slow_query_log_timeout_in_ms; + conf.cross_node_timeout = crossNodeTimeout; + } + + public static long getSlowQueryTimeout(TimeUnit units) + { + return units.convert(conf.slow_query_log_timeout_in_ms, MILLISECONDS); } /** * @return the minimum configured {read, 
write, range, truncate, misc} timeout */ - public static long getMinRpcTimeout() + public static long getMinRpcTimeout(TimeUnit unit) { - return Longs.min(getRpcTimeout(), - getReadRpcTimeout(), - getRangeRpcTimeout(), - getWriteRpcTimeout(), - getCounterWriteRpcTimeout(), - getTruncateRpcTimeout()); + return Longs.min(getRpcTimeout(unit), + getReadRpcTimeout(unit), + getRangeRpcTimeout(unit), + getWriteRpcTimeout(unit), + getCounterWriteRpcTimeout(unit), + getTruncateRpcTimeout(unit)); } - public static long getPingTimeout() + public static long getPingTimeout(TimeUnit unit) { - return TimeUnit.SECONDS.toMillis(getBlockForPeersTimeoutInSeconds()); + return unit.convert(getBlockForPeersTimeoutInSeconds(), TimeUnit.SECONDS); } public static double getPhiConvictThreshold() @@ -1833,14 +1888,44 @@ public static boolean getRpcKeepAlive() return conf.rpc_keepalive; } - public static int getInternodeSendBufferSize() + public static int getInternodeSocketSendBufferSizeInBytes() + { + return conf.internode_socket_send_buffer_size_in_bytes; + } + + public static int getInternodeSocketReceiveBufferSizeInBytes() + { + return conf.internode_socket_receive_buffer_size_in_bytes; + } + + public static int getInternodeApplicationSendQueueCapacityInBytes() + { + return conf.internode_application_send_queue_capacity_in_bytes; + } + + public static int getInternodeApplicationSendQueueReserveEndpointCapacityInBytes() { - return conf.internode_send_buff_size_in_bytes; + return conf.internode_application_send_queue_reserve_endpoint_capacity_in_bytes; } - public static int getInternodeRecvBufferSize() + public static int getInternodeApplicationSendQueueReserveGlobalCapacityInBytes() { - return conf.internode_recv_buff_size_in_bytes; + return conf.internode_application_send_queue_reserve_global_capacity_in_bytes; + } + + public static int getInternodeApplicationReceiveQueueCapacityInBytes() + { + return conf.internode_application_receive_queue_capacity_in_bytes; + } + + public static int 
getInternodeApplicationReceiveQueueReserveEndpointCapacityInBytes() + { + return conf.internode_application_receive_queue_reserve_endpoint_capacity_in_bytes; + } + + public static int getInternodeApplicationReceiveQueueReserveGlobalCapacityInBytes() + { + return conf.internode_application_receive_queue_reserve_global_capacity_in_bytes; } public static int getInternodeTcpConnectTimeoutInMS() @@ -1863,6 +1948,17 @@ public static void setInternodeTcpUserTimeoutInMS(int value) conf.internode_tcp_user_timeout_in_ms = value; } + public static int getInternodeMaxMessageSizeInBytes() + { + return conf.internode_max_message_size_in_bytes; + } + + @VisibleForTesting + public static void setInternodeMaxMessageSizeInBytes(int value) + { + conf.internode_max_message_size_in_bytes = value; + } + public static boolean startNativeTransport() { return conf.start_native_transport; @@ -1954,6 +2050,26 @@ public static void setCommitLogSyncGroupWindow(double windowMillis) conf.commitlog_sync_group_window_in_ms = windowMillis; } + public static long getNativeTransportMaxConcurrentRequestsInBytesPerIp() + { + return conf.native_transport_max_concurrent_requests_in_bytes_per_ip; + } + + public static void setNativeTransportMaxConcurrentRequestsInBytesPerIp(long maxConcurrentRequestsInBytes) + { + conf.native_transport_max_concurrent_requests_in_bytes_per_ip = maxConcurrentRequestsInBytes; + } + + public static long getNativeTransportMaxConcurrentRequestsInBytes() + { + return conf.native_transport_max_concurrent_requests_in_bytes; + } + + public static void setNativeTransportMaxConcurrentRequestsInBytes(long maxConcurrentRequestsInBytes) + { + conf.native_transport_max_concurrent_requests_in_bytes = maxConcurrentRequestsInBytes; + } + public static int getCommitLogSyncPeriod() { return conf.commitlog_sync_period_in_ms; @@ -2142,6 +2258,12 @@ public static EncryptionOptions getNativeProtocolEncryptionOptions() return conf.client_encryption_options; } + @VisibleForTesting + public static 
void updateNativeProtocolEncryptionOptions(Function update) + { + conf.client_encryption_options = update.apply(conf.client_encryption_options); + } + public static int getHintedHandoffThrottleInKB() { return conf.hinted_handoff_throttle_in_kb; @@ -2485,41 +2607,6 @@ public static int getTracetypeQueryTTL() return conf.tracetype_query_ttl; } - public static String getOtcCoalescingStrategy() - { - return conf.otc_coalescing_strategy; - } - - public static void setOtcCoalescingStrategy(String strategy) - { - conf.otc_coalescing_strategy = strategy; - } - - public static int getOtcCoalescingWindow() - { - return conf.otc_coalescing_window_us; - } - - public static int getOtcCoalescingEnoughCoalescedMessages() - { - return conf.otc_coalescing_enough_coalesced_messages; - } - - public static void setOtcCoalescingEnoughCoalescedMessages(int otc_coalescing_enough_coalesced_messages) - { - conf.otc_coalescing_enough_coalesced_messages = otc_coalescing_enough_coalesced_messages; - } - - public static int getOtcBacklogExpirationInterval() - { - return conf.otc_backlog_expiration_interval_ms; - } - - public static void setOtcBacklogExpirationInterval(int intervalInMillis) - { - conf.otc_backlog_expiration_interval_ms = intervalInMillis; - } - public static int getWindowsTimerInterval() { return conf.windows_timer_interval; @@ -2825,4 +2912,15 @@ public static boolean strictRuntimeChecks() { return strictRuntimeChecks; } + + public static boolean useOffheapMerkleTrees() + { + return conf.use_offheap_merkle_trees; + } + + public static void useOffheapMerkleTrees(boolean value) + { + logger.info("Setting use_offheap_merkle_trees to {}", value); + conf.use_offheap_merkle_trees = value; + } } diff --git a/src/java/org/apache/cassandra/config/EncryptionOptions.java b/src/java/org/apache/cassandra/config/EncryptionOptions.java index 9524cec36ec0..0a33dcc66778 100644 --- a/src/java/org/apache/cassandra/config/EncryptionOptions.java +++ 
b/src/java/org/apache/cassandra/config/EncryptionOptions.java @@ -17,30 +17,61 @@ */ package org.apache.cassandra.config; -import java.util.Arrays; +import java.util.List; import java.util.Objects; +import com.google.common.collect.ImmutableList; + +import org.apache.cassandra.locator.IEndpointSnitch; +import org.apache.cassandra.locator.InetAddressAndPort; + public class EncryptionOptions { - public String keystore = "conf/.keystore"; - public String keystore_password = "cassandra"; - public String truststore = "conf/.truststore"; - public String truststore_password = "cassandra"; - public String[] cipher_suites = {}; - public String protocol = "TLS"; - public String algorithm = null; - public String store_type = "JKS"; - public boolean require_client_auth = false; - public boolean require_endpoint_verification = false; - public boolean enabled = false; - public boolean optional = false; + public final String keystore; + public final String keystore_password; + public final String truststore; + public final String truststore_password; + public final List cipher_suites; + public final String protocol; + public final String algorithm; + public final String store_type; + public final boolean require_client_auth; + public final boolean require_endpoint_verification; + public final boolean enabled; + public final boolean optional; public EncryptionOptions() - { } + { + keystore = "conf/.keystore"; + keystore_password = "cassandra"; + truststore = "conf/.truststore"; + truststore_password = "cassandra"; + cipher_suites = ImmutableList.of(); + protocol = "TLS"; + algorithm = null; + store_type = "JKS"; + require_client_auth = false; + require_endpoint_verification = false; + enabled = false; + optional = false; + } + + public EncryptionOptions(String keystore, String keystore_password, String truststore, String truststore_password, List cipher_suites, String protocol, String algorithm, String store_type, boolean require_client_auth, boolean require_endpoint_verification, 
boolean enabled, boolean optional) + { + this.keystore = keystore; + this.keystore_password = keystore_password; + this.truststore = truststore; + this.truststore_password = truststore_password; + this.cipher_suites = cipher_suites; + this.protocol = protocol; + this.algorithm = algorithm; + this.store_type = store_type; + this.require_client_auth = require_client_auth; + this.require_endpoint_verification = require_endpoint_verification; + this.enabled = enabled; + this.optional = optional; + } - /** - * Copy constructor - */ public EncryptionOptions(EncryptionOptions options) { keystore = options.keystore; @@ -57,6 +88,97 @@ public EncryptionOptions(EncryptionOptions options) optional = options.optional; } + public EncryptionOptions withKeyStore(String keystore) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withKeyStorePassword(String keystore_password) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withTrustStore(String truststore) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withTrustStorePassword(String truststore_password) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withCipherSuites(List cipher_suites) + { + return new 
EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withCipherSuites(String ... cipher_suites) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, ImmutableList.copyOf(cipher_suites), + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withProtocol(String protocol) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withAlgorithm(String algorithm) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withStoreType(String store_type) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withRequireClientAuth(boolean require_client_auth) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withRequireEndpointVerification(boolean require_endpoint_verification) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, 
optional); + } + + public EncryptionOptions withEnabled(boolean enabled) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + + public EncryptionOptions withOptional(boolean optional) + { + return new EncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional); + } + /** * The method is being mainly used to cache SslContexts therefore, we only consider * fields that would make a difference when the TrustStore or KeyStore files are updated @@ -81,7 +203,7 @@ public boolean equals(Object o) Objects.equals(protocol, opt.protocol) && Objects.equals(algorithm, opt.algorithm) && Objects.equals(store_type, opt.store_type) && - Arrays.equals(cipher_suites, opt.cipher_suites); + Objects.equals(cipher_suites, opt.cipher_suites); } /** @@ -101,7 +223,7 @@ public int hashCode() result += 31 * (store_type == null ? 0 : store_type.hashCode()); result += 31 * Boolean.hashCode(enabled); result += 31 * Boolean.hashCode(optional); - result += 31 * Arrays.hashCode(cipher_suites); + result += 31 * (cipher_suites == null ? 
0 : cipher_suites.hashCode()); result += 31 * Boolean.hashCode(require_client_auth); result += 31 * Boolean.hashCode(require_endpoint_verification); return result; @@ -114,20 +236,156 @@ public enum InternodeEncryption all, none, dc, rack } - public InternodeEncryption internode_encryption = InternodeEncryption.none; - public boolean enable_legacy_ssl_storage_port = false; + public final InternodeEncryption internode_encryption; + public final boolean enable_legacy_ssl_storage_port; public ServerEncryptionOptions() - { } + { + this.internode_encryption = InternodeEncryption.none; + this.enable_legacy_ssl_storage_port = false; + } + public ServerEncryptionOptions(String keystore, String keystore_password, String truststore, String truststore_password, List cipher_suites, String protocol, String algorithm, String store_type, boolean require_client_auth, boolean require_endpoint_verification, boolean enabled, boolean optional, InternodeEncryption internode_encryption, boolean enable_legacy_ssl_storage_port) + { + super(keystore, keystore_password, truststore, truststore_password, cipher_suites, protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, enabled, optional); + this.internode_encryption = internode_encryption; + this.enable_legacy_ssl_storage_port = enable_legacy_ssl_storage_port; + } - /** - * Copy constructor - */ public ServerEncryptionOptions(ServerEncryptionOptions options) { super(options); - internode_encryption = options.internode_encryption; - enable_legacy_ssl_storage_port = options.enable_legacy_ssl_storage_port; + this.internode_encryption = options.internode_encryption; + this.enable_legacy_ssl_storage_port = options.enable_legacy_ssl_storage_port; } + + public boolean shouldEncrypt(InetAddressAndPort endpoint) + { + IEndpointSnitch snitch = DatabaseDescriptor.getEndpointSnitch(); + switch (internode_encryption) + { + case none: + return false; // if nothing needs to be encrypted then return immediately. 
+ case all: + break; + case dc: + if (snitch.getDatacenter(endpoint).equals(snitch.getLocalDatacenter())) + return false; + break; + case rack: + // for rack then check if the DC's are the same. + if (snitch.getRack(endpoint).equals(snitch.getLocalRack()) + && snitch.getDatacenter(endpoint).equals(snitch.getLocalDatacenter())) + return false; + break; + } + return true; + } + + + public ServerEncryptionOptions withKeyStore(String keystore) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withKeyStorePassword(String keystore_password) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withTrustStore(String truststore) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withTrustStorePassword(String truststore_password) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withCipherSuites(List cipher_suites) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, 
cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withCipherSuites(String ... cipher_suites) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, ImmutableList.copyOf(cipher_suites), + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withProtocol(String protocol) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withAlgorithm(String algorithm) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withStoreType(String store_type) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withRequireClientAuth(boolean require_client_auth) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + 
public ServerEncryptionOptions withRequireEndpointVerification(boolean require_endpoint_verification) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withEnabled(boolean enabled) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withOptional(boolean optional) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withInternodeEncryption(InternodeEncryption internode_encryption) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + + public ServerEncryptionOptions withLegacySslStoragePort(boolean enable_legacy_ssl_storage_port) + { + return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password, cipher_suites, + protocol, algorithm, store_type, require_client_auth, require_endpoint_verification, + enabled, optional, internode_encryption, enable_legacy_ssl_storage_port); + } + } } diff --git a/src/java/org/apache/cassandra/cql3/statements/schema/AlterTableStatement.java 
b/src/java/org/apache/cassandra/cql3/statements/schema/AlterTableStatement.java index c348cc4f8ace..6410e67a70e6 100644 --- a/src/java/org/apache/cassandra/cql3/statements/schema/AlterTableStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/schema/AlterTableStatement.java @@ -166,7 +166,7 @@ private void addColumn(KeyspaceMetadata keyspace, { // After #8099, not safe to re-add columns of incompatible types - until *maybe* deser logic with dropped // columns is pushed deeper down the line. The latter would still be problematic in cases of schema races. - if (!droppedColumn.type.isValueCompatibleWith(type)) + if (!type.isValueCompatibleWith(droppedColumn.type)) { throw ire("Cannot re-add previously dropped column '%s' of type %s, incompatible with previous type %s", name, diff --git a/src/java/org/apache/cassandra/db/AbstractCompactionController.java b/src/java/org/apache/cassandra/db/AbstractCompactionController.java new file mode 100644 index 000000000000..99193f8626a5 --- /dev/null +++ b/src/java/org/apache/cassandra/db/AbstractCompactionController.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db; + +import java.util.function.LongPredicate; + +import org.apache.cassandra.db.rows.UnfilteredRowIterator; +import org.apache.cassandra.schema.CompactionParams; + +/** + * AbstractCompactionController allows custom implementations of the CompactionController for use in tooling, without being tied to the SSTableReader and local filesystem + */ +public abstract class AbstractCompactionController implements AutoCloseable +{ + public final ColumnFamilyStore cfs; + public final int gcBefore; + public final CompactionParams.TombstoneOption tombstoneOption; + + public AbstractCompactionController(final ColumnFamilyStore cfs, final int gcBefore, CompactionParams.TombstoneOption tombstoneOption) + { + assert cfs != null; + this.cfs = cfs; + this.gcBefore = gcBefore; + this.tombstoneOption = tombstoneOption; + } + + public abstract boolean compactingRepaired(); + + public String getKeyspace() + { + return cfs.keyspace.getName(); + } + + public String getColumnFamily() + { + return cfs.name; + } + + public Iterable shadowSources(DecoratedKey key, boolean tombstoneOnly) + { + return null; + } + + public abstract LongPredicate getPurgeEvaluator(DecoratedKey key); +} diff --git a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java index c09b88477f43..f20d74c1d702 100644 --- a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java +++ b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java @@ -87,6 +87,9 @@ import org.json.simple.JSONArray; import org.json.simple.JSONObject; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdown; import static org.apache.cassandra.utils.Throwables.maybeFail; public class ColumnFamilyStore implements ColumnFamilyStoreMBean @@ -217,31 +220,18 @@ public class ColumnFamilyStore implements 
ColumnFamilyStoreMBean private volatile boolean neverPurgeTombstones = false; - public static void shutdownFlushExecutor() throws InterruptedException - { - flushExecutor.shutdown(); - flushExecutor.awaitTermination(60, TimeUnit.SECONDS); - } - - public static void shutdownPostFlushExecutor() throws InterruptedException { postFlushExecutor.shutdown(); postFlushExecutor.awaitTermination(60, TimeUnit.SECONDS); } - public static void shutdownReclaimExecutor() throws InterruptedException - { - reclaimExecutor.shutdown(); - reclaimExecutor.awaitTermination(60, TimeUnit.SECONDS); - } - - public static void shutdownPerDiskFlushExecutors() throws InterruptedException + public static void shutdownExecutorsAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - for (ExecutorService executorService : perDiskflushExecutors) - executorService.shutdown(); - for (ExecutorService executorService : perDiskflushExecutors) - executorService.awaitTermination(60, TimeUnit.SECONDS); + List executors = new ArrayList<>(perDiskflushExecutors.length + 3); + Collections.addAll(executors, reclaimExecutor, postFlushExecutor, flushExecutor); + Collections.addAll(executors, perDiskflushExecutors); + ExecutorUtils.shutdownAndWait(timeout, unit, executors); } public void reload() @@ -295,7 +285,7 @@ protected void runMayThrow() else { // we'll be rescheduled by the constructor of the Memtable. 
- forceFlush(); + forceFlushToSSTable(); } } } @@ -401,8 +391,8 @@ public ColumnFamilyStore(Keyspace keyspace, viewManager = keyspace.viewManager.forTable(metadata.id); metric = new TableMetrics(this); fileIndexGenerator.set(generation); - sampleReadLatencyNanos = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getReadRpcTimeout() / 2); - additionalWriteLatencyNanos = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getWriteRpcTimeout() / 2); + sampleReadLatencyNanos = DatabaseDescriptor.getReadRpcTimeout(NANOSECONDS) / 2; + additionalWriteLatencyNanos = DatabaseDescriptor.getWriteRpcTimeout(NANOSECONDS) / 2; logger.info("Initializing {}.{}", keyspace.getName(), name); @@ -868,7 +858,7 @@ private void logFlush() * @return a Future yielding the commit log position that can be guaranteed to have been successfully written * to sstables for this table once the future completes */ - public ListenableFuture forceFlush() + public ListenableFuture forceFlushToSSTable() { synchronized (data) { @@ -887,7 +877,7 @@ public ListenableFuture forceFlush() * @return a Future yielding the commit log position that can be guaranteed to have been successfully written * to sstables for this table once the future completes */ - public ListenableFuture forceFlush(CommitLogPosition flushIfDirtyBefore) + public ListenableFuture forceFlushToSSTable(CommitLogPosition flushIfDirtyBefore) { // we don't loop through the remaining memtables since here we only care about commit log dirtiness // and this does not vary between a table and its table-backed indexes @@ -914,9 +904,9 @@ private ListenableFuture waitForFlushes() return task; } - public CommitLogPosition forceBlockingFlush() + public CommitLogPosition forceBlockingFlushToSSTable() { - return FBUtilities.waitOnFuture(forceFlush()); + return FBUtilities.waitOnFuture(forceFlushToSSTable()); } /** @@ -1896,7 +1886,7 @@ public Set snapshot(String snapshotName, Predicate { if (!skipFlush) { - forceBlockingFlush(); + 
forceBlockingFlushToSSTable(); } return snapshotWithoutFlush(snapshotName, predicate, ephemeral); } @@ -2120,7 +2110,7 @@ public void truncateBlocking() if (keyspace.getMetadata().params.durableWrites || DatabaseDescriptor.isAutoSnapshot()) { - replayAfter = forceBlockingFlush(); + replayAfter = forceBlockingFlushToSSTable(); viewManager.forceBlockingFlush(); } else @@ -2390,14 +2380,14 @@ private void validateCompactionThresholds(int minThreshold, int maxThreshold) // End JMX get/set. - public int getMeanColumns() + public int getMeanEstimatedCellPerPartitionCount() { long sum = 0; long count = 0; for (SSTableReader sstable : getSSTables(SSTableSet.CANONICAL)) { - long n = sstable.getEstimatedColumnCount().count(); - sum += sstable.getEstimatedColumnCount().mean() * n; + long n = sstable.getEstimatedCellPerPartitionCount().count(); + sum += sstable.getEstimatedCellPerPartitionCount().mean() * n; count += n; } return count > 0 ? (int) (sum / count) : 0; @@ -2416,6 +2406,19 @@ public double getMeanPartitionSize() return count > 0 ? sum * 1.0 / count : 0; } + public int getMeanRowCount() + { + long totalRows = 0; + long totalPartitions = 0; + for (SSTableReader sstable : getSSTables(SSTableSet.CANONICAL)) + { + totalPartitions += sstable.getEstimatedPartitionSize().count(); + totalRows += sstable.getTotalRows(); + } + + return totalPartitions > 0 ? (int) (totalRows / totalPartitions) : 0; + } + public long estimateKeys() { long n = 0; @@ -2556,7 +2559,7 @@ public double getDroppableTombstoneRatio() for (SSTableReader sstable : getSSTables(SSTableSet.LIVE)) { allDroppable += sstable.getDroppableTombstonesBefore(localTime - metadata().params.gcGraceSeconds); - allColumns += sstable.getEstimatedColumnCount().mean() * sstable.getEstimatedColumnCount().count(); + allColumns += sstable.getEstimatedCellPerPartitionCount().mean() * sstable.getEstimatedCellPerPartitionCount().count(); } return allColumns > 0 ? 
allDroppable / allColumns : 0; } diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index 50817399ca93..bf9e17492dda 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -28,6 +28,7 @@ import com.google.common.hash.Hasher; import net.nicoulaj.compilecommand.annotations.DontInline; +import org.apache.cassandra.exceptions.UnknownColumnException; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.cql3.ColumnIdentifier; @@ -454,7 +455,7 @@ public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOEx // deserialization. The column will be ignore later on anyway. column = metadata.getDroppedColumn(name); if (column == null) - throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); + throw new UnknownColumnException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); } builder.add(column); } diff --git a/src/java/org/apache/cassandra/db/CounterMutation.java b/src/java/org/apache/cassandra/db/CounterMutation.java index d04ddd8b909f..bb10a6a7a956 100644 --- a/src/java/org/apache/cassandra/db/CounterMutation.java +++ b/src/java/org/apache/cassandra/db/CounterMutation.java @@ -24,7 +24,6 @@ import com.google.common.base.Function; import com.google.common.base.Objects; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; @@ -39,14 +38,14 @@ import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.TableId; import 
org.apache.cassandra.service.CacheService; import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.utils.*; import org.apache.cassandra.utils.btree.BTreeSet; +import static java.util.concurrent.TimeUnit.*; + public class CounterMutation implements IMutation { public static final CounterMutationSerializer serializer = new CounterMutationSerializer(); @@ -92,11 +91,6 @@ public ConsistencyLevel consistency() return consistency; } - public MessageOut makeMutationMessage() - { - return new MessageOut<>(MessagingService.Verb.COUNTER_MUTATION, this, serializer); - } - /** * Applies the counter mutation, returns the result Mutation (for replication to other nodes). * @@ -146,10 +140,10 @@ private void grabCounterLocks(Keyspace keyspace, List locks) throws WriteT for (Lock lock : LOCKS.bulkGet(getCounterLockKeys())) { - long timeout = TimeUnit.MILLISECONDS.toNanos(getTimeout()) - (System.nanoTime() - startTime); + long timeout = getTimeout(NANOSECONDS) - (System.nanoTime() - startTime); try { - if (!lock.tryLock(timeout, TimeUnit.NANOSECONDS)) + if (!lock.tryLock(timeout, NANOSECONDS)) throw new WriteTimeoutException(WriteType.COUNTER, consistency(), 0, consistency().blockFor(keyspace)); locks.add(lock); } @@ -309,9 +303,9 @@ private void updateForRow(PeekingIterator markIter, } } - public long getTimeout() + public long getTimeout(TimeUnit unit) { - return DatabaseDescriptor.getCounterWriteRpcTimeout(); + return DatabaseDescriptor.getCounterWriteRpcTimeout(unit); } @Override diff --git a/src/java/org/apache/cassandra/db/CounterMutationVerbHandler.java b/src/java/org/apache/cassandra/db/CounterMutationVerbHandler.java index c946ea595fef..a30ce665beba 100644 --- a/src/java/org/apache/cassandra/db/CounterMutationVerbHandler.java +++ b/src/java/org/apache/cassandra/db/CounterMutationVerbHandler.java @@ -22,16 +22,17 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.net.IVerbHandler; -import 
org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageProxy; -import org.apache.cassandra.utils.FBUtilities; public class CounterMutationVerbHandler implements IVerbHandler { + public static final CounterMutationVerbHandler instance = new CounterMutationVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(CounterMutationVerbHandler.class); - public void doVerb(final MessageIn message, final int id) + public void doVerb(final Message message) { long queryStartNanoTime = System.nanoTime(); final CounterMutation cm = message.payload; @@ -45,12 +46,9 @@ public void doVerb(final MessageIn message, final int id) // will not be called if the request timeout, but this is ok // because the coordinator of the counter mutation will timeout on // it's own in that case. - StorageProxy.applyCounterMutationOnLeader(cm, localDataCenter, new Runnable() - { - public void run() - { - MessagingService.instance().sendReply(WriteResponse.createMessage(), id, message.from); - } - }, queryStartNanoTime); + StorageProxy.applyCounterMutationOnLeader(cm, + localDataCenter, + () -> MessagingService.instance().send(message.emptyResponse(), message.from()), + queryStartNanoTime); } } diff --git a/src/java/org/apache/cassandra/db/IMutation.java b/src/java/org/apache/cassandra/db/IMutation.java index 9eaf19b4922f..1710cfd12379 100644 --- a/src/java/org/apache/cassandra/db/IMutation.java +++ b/src/java/org/apache/cassandra/db/IMutation.java @@ -18,6 +18,7 @@ package org.apache.cassandra.db; import java.util.Collection; +import java.util.concurrent.TimeUnit; import org.apache.cassandra.db.partitions.PartitionUpdate; import org.apache.cassandra.schema.TableId; @@ -28,7 +29,7 @@ public interface IMutation public String getKeyspaceName(); public Collection getTableIds(); public DecoratedKey key(); - public long getTimeout(); + public long getTimeout(TimeUnit 
unit); public String toString(boolean shallow); public Collection getPartitionUpdates(); diff --git a/src/java/org/apache/cassandra/db/Keyspace.java b/src/java/org/apache/cassandra/db/Keyspace.java index bc382eef82b8..fa1f7c87debb 100644 --- a/src/java/org/apache/cassandra/db/Keyspace.java +++ b/src/java/org/apache/cassandra/db/Keyspace.java @@ -56,6 +56,9 @@ import org.apache.cassandra.utils.*; import org.apache.cassandra.utils.concurrent.OpOrder; +import static java.util.concurrent.TimeUnit.*; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + /** * It represents a Keyspace. */ @@ -395,7 +398,7 @@ public void dropCf(TableId tableId) // disassociate a cfs from this keyspace instance. private void unloadCf(ColumnFamilyStore cfs) { - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.invalidate(); } @@ -544,7 +547,7 @@ private CompletableFuture applyInternal(final Mutation mutation, if (lock == null) { //throw WTE only if request is droppable - if (isDroppable && (System.currentTimeMillis() - mutation.createdAt) > DatabaseDescriptor.getWriteRpcTimeout()) + if (isDroppable && (approxTime.isAfter(mutation.approxCreatedAtNanos + DatabaseDescriptor.getWriteRpcTimeout(NANOSECONDS)))) { for (int j = 0; j < i; j++) locks[j].unlock(); @@ -605,7 +608,7 @@ else if (isDeferrable) if (isDroppable) { for(TableId tableId : tableIds) - columnFamilyStores.get(tableId).metric.viewLockAcquireTime.update(acquireTime, TimeUnit.MILLISECONDS); + columnFamilyStores.get(tableId).metric.viewLockAcquireTime.update(acquireTime, MILLISECONDS); } } int nowInSec = FBUtilities.nowInSeconds(); @@ -671,7 +674,7 @@ public List> flush() { List> futures = new ArrayList<>(columnFamilyStores.size()); for (ColumnFamilyStore cfs : columnFamilyStores.values()) - futures.add(cfs.forceFlush()); + futures.add(cfs.forceFlushToSSTable()); return futures; } diff --git a/src/java/org/apache/cassandra/db/Mutation.java b/src/java/org/apache/cassandra/db/Mutation.java index 
6195fe4c87a5..22c4ed83766f 100644 --- a/src/java/org/apache/cassandra/db/Mutation.java +++ b/src/java/org/apache/cassandra/db/Mutation.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import com.google.common.collect.ImmutableCollection; @@ -32,13 +33,13 @@ import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.utils.ByteBufferUtil; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + public class Mutation implements IMutation { public static final MutationSerializer serializer = new MutationSerializer(); @@ -52,7 +53,7 @@ public class Mutation implements IMutation private final ImmutableMap modifications; // Time at which this mutation or the builder that built it was instantiated - final long createdAt; + final long approxCreatedAtNanos; // keep track of when mutation has started waiting for a MV partition lock final AtomicLong viewLockAcquireStart = new AtomicLong(0); @@ -60,10 +61,10 @@ public class Mutation implements IMutation public Mutation(PartitionUpdate update) { - this(update.metadata().keyspace, update.partitionKey(), ImmutableMap.of(update.metadata().id, update), System.currentTimeMillis()); + this(update.metadata().keyspace, update.partitionKey(), ImmutableMap.of(update.metadata().id, update), approxTime.now()); } - public Mutation(String keyspaceName, DecoratedKey key, ImmutableMap modifications, long createdAt) + public Mutation(String keyspaceName, DecoratedKey key, ImmutableMap modifications, long approxCreatedAtNanos) { 
this.keyspaceName = keyspaceName; this.key = key; @@ -73,7 +74,7 @@ public Mutation(String keyspaceName, DecoratedKey key, ImmutableMap tableIds) @@ -90,7 +91,7 @@ public Mutation without(Set tableIds) } } - return new Mutation(keyspaceName, key, builder.build(), createdAt); + return new Mutation(keyspaceName, key, builder.build(), approxCreatedAtNanos); } public Mutation without(TableId tableId) @@ -177,7 +178,7 @@ public static Mutation merge(List mutations) modifications.put(table, updates.size() == 1 ? updates.get(0) : PartitionUpdate.merge(updates)); updates.clear(); } - return new Mutation(ks, key, modifications.build(), System.currentTimeMillis()); + return new Mutation(ks, key, modifications.build(), approxTime.now()); } public CompletableFuture applyFuture() @@ -210,19 +211,9 @@ public void applyUnsafe() apply(false); } - public MessageOut createMessage() - { - return createMessage(MessagingService.Verb.MUTATION); - } - - public MessageOut createMessage(MessagingService.Verb verb) - { - return new MessageOut<>(verb, this, serializer); - } - - public long getTimeout() + public long getTimeout(TimeUnit unit) { - return DatabaseDescriptor.getWriteRpcTimeout(); + return DatabaseDescriptor.getWriteRpcTimeout(unit); } public int smallestGCGS() @@ -363,7 +354,7 @@ public Mutation deserialize(DataInputPlus in, int version, SerializationHelper.F update = PartitionUpdate.serializer.deserialize(in, version, flag); modifications.put(update.metadata().id, update); } - return new Mutation(update.metadata().keyspace, dk, modifications.build(), System.currentTimeMillis()); + return new Mutation(update.metadata().keyspace, dk, modifications.build(), approxTime.now()); } public Mutation deserialize(DataInputPlus in, int version) throws IOException @@ -389,7 +380,7 @@ public static class PartitionUpdateCollector private final ImmutableMap.Builder modifications = new ImmutableMap.Builder<>(); private final String keyspaceName; private final DecoratedKey key; - private final 
long createdAt = System.currentTimeMillis(); + private final long approxCreatedAtNanos = approxTime.now(); private boolean empty = true; public PartitionUpdateCollector(String keyspaceName, DecoratedKey key) @@ -425,7 +416,7 @@ public boolean isEmpty() public Mutation build() { - return new Mutation(keyspaceName, key, modifications.build(), createdAt); + return new Mutation(keyspaceName, key, modifications.build(), approxCreatedAtNanos); } } } diff --git a/src/java/org/apache/cassandra/db/MutationVerbHandler.java b/src/java/org/apache/cassandra/db/MutationVerbHandler.java index 9660f658dd35..bcb9cc7aaee3 100644 --- a/src/java/org/apache/cassandra/db/MutationVerbHandler.java +++ b/src/java/org/apache/cassandra/db/MutationVerbHandler.java @@ -17,8 +17,6 @@ */ package org.apache.cassandra.db; -import java.util.Iterator; - import org.apache.cassandra.exceptions.WriteTimeoutException; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.*; @@ -26,10 +24,12 @@ public class MutationVerbHandler implements IVerbHandler { - private void reply(int id, InetAddressAndPort replyTo) + public static final MutationVerbHandler instance = new MutationVerbHandler(); + + private void respond(Message respondTo, InetAddressAndPort respondToAddress) { - Tracing.trace("Enqueuing response to {}", replyTo); - MessagingService.instance().sendReply(WriteResponse.createMessage(), id, replyTo); + Tracing.trace("Enqueuing response to {}", respondToAddress); + MessagingService.instance().send(respondTo.emptyResponse(), respondToAddress); } private void failed() @@ -37,27 +37,25 @@ private void failed() Tracing.trace("Payload application resulted in WriteTimeout, not replying"); } - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { // Check if there were any forwarding headers in this message - InetAddressAndPort from = (InetAddressAndPort)message.parameters.get(ParameterType.FORWARD_FROM); - InetAddressAndPort replyTo; + 
InetAddressAndPort from = message.respondTo(); + InetAddressAndPort respondToAddress; if (from == null) { - replyTo = message.from; - ForwardToContainer forwardTo = (ForwardToContainer)message.parameters.get(ParameterType.FORWARD_TO); - if (forwardTo != null) - forwardToLocalNodes(message.payload, message.verb, forwardTo, message.from); + respondToAddress = message.from(); + ForwardingInfo forwardTo = message.forwardTo(); + if (forwardTo != null) forwardToLocalNodes(message, forwardTo); } else { - - replyTo = from; + respondToAddress = from; } try { - message.payload.applyFuture().thenAccept(o -> reply(id, replyTo)).exceptionally(wto -> { + message.payload.applyFuture().thenAccept(o -> respond(message, respondToAddress)).exceptionally(wto -> { failed(); return null; }); @@ -68,17 +66,21 @@ public void doVerb(MessageIn message, int id) } } - private static void forwardToLocalNodes(Mutation mutation, MessagingService.Verb verb, ForwardToContainer forwardTo, InetAddressAndPort from) + private static void forwardToLocalNodes(Message originalMessage, ForwardingInfo forwardTo) { - // tell the recipients who to send their ack to - MessageOut message = new MessageOut<>(verb, mutation, Mutation.serializer).withParameter(ParameterType.FORWARD_FROM, from); - Iterator iterator = forwardTo.targets.iterator(); - // Send a message to each of the addresses on our Forward List - for (int i = 0; i < forwardTo.targets.size(); i++) + Message.Builder builder = + Message.builder(originalMessage) + .withParam(ParamType.RESPOND_TO, originalMessage.from()) + .withoutParam(ParamType.FORWARD_TO); + + boolean useSameMessageID = forwardTo.useSameMessageID(); + // reuse the same Message if all ids are identical (as they will be for 4.0+ node originated messages) + Message message = useSameMessageID ? 
builder.build() : null; + + forwardTo.forEach((id, target) -> { - InetAddressAndPort address = iterator.next(); - Tracing.trace("Enqueuing forwarded write to {}", address); - MessagingService.instance().sendOneWay(message, forwardTo.messageIds[i], address); - } + Tracing.trace("Enqueuing forwarded write to {}", target); + MessagingService.instance().send(useSameMessageID ? message : builder.withId(id).build(), target); + }); } } diff --git a/src/java/org/apache/cassandra/db/PartitionRangeReadCommand.java b/src/java/org/apache/cassandra/db/PartitionRangeReadCommand.java index b5f6fb534214..2145389607a6 100644 --- a/src/java/org/apache/cassandra/db/PartitionRangeReadCommand.java +++ b/src/java/org/apache/cassandra/db/PartitionRangeReadCommand.java @@ -18,9 +18,12 @@ package org.apache.cassandra.db; import java.io.IOException; +import java.util.concurrent.TimeUnit; import com.google.common.annotations.VisibleForTesting; +import org.apache.cassandra.net.MessageFlag; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.filter.*; @@ -38,8 +41,7 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.metrics.TableMetrics; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Message; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.service.ClientState; import org.apache.cassandra.service.StorageProxy; @@ -233,9 +235,9 @@ public PartitionRangeReadCommand withUpdatedLimitsAndDataRange(DataLimits newLim indexMetadata()); } - public long getTimeout() + public long getTimeout(TimeUnit unit) { - return DatabaseDescriptor.getRangeRpcTimeout(); + return DatabaseDescriptor.getRangeRpcTimeout(unit); } public PartitionIterator execute(ConsistencyLevel consistency, ClientState 
clientState, long queryStartNanoTime) throws RequestExecutionException @@ -345,9 +347,10 @@ public BaseRowIterator applyToPartition(BaseRowIterator iter) return Transformation.apply(iter, new CacheFilter()); } - public MessageOut createMessage() + @Override + public Verb verb() { - return new MessageOut<>(MessagingService.Verb.RANGE_SLICE, this, serializer); + return Verb.RANGE_REQ; } protected void appendCQLWhereClause(StringBuilder sb) @@ -414,6 +417,11 @@ public boolean isLimitedToOnePartition() && dataRange.startKey().equals(dataRange.stopKey()); } + public boolean isRangeRequest() + { + return true; + } + private static class Deserializer extends SelectionDeserializer { public ReadCommand deserialize(DataInputPlus in, diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index 32b91ad8905e..68ce2eacbf9b 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.*; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.LongPredicate; @@ -35,7 +36,9 @@ import org.apache.cassandra.config.*; import org.apache.cassandra.db.filter.*; -import org.apache.cassandra.db.monitoring.ApproximateTime; +import org.apache.cassandra.net.MessageFlag; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.utils.ApproximateTime; import org.apache.cassandra.db.partitions.*; import org.apache.cassandra.db.rows.*; import org.apache.cassandra.db.transform.RTBoundCloser; @@ -52,9 +55,8 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.locator.ReplicaCollection; import org.apache.cassandra.metrics.TableMetrics; -import org.apache.cassandra.net.MessageOut; +import 
org.apache.cassandra.net.Message; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; @@ -69,6 +71,7 @@ import static com.google.common.collect.Iterables.any; import static com.google.common.collect.Iterables.filter; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; /** * General interface for storage-engine read commands (common to both range and @@ -164,6 +167,8 @@ protected ReadCommand(Kind kind, public abstract boolean isLimitedToOnePartition(); + public abstract boolean isRangeRequest(); + /** * Creates a new ReadCommand instance with new limits. * @@ -177,7 +182,7 @@ protected ReadCommand(Kind kind, * * @return the configured timeout for this command. */ - public abstract long getTimeout(); + public abstract long getTimeout(TimeUnit unit); /** * Whether this query is a digest one or not. @@ -628,14 +633,15 @@ protected Row applyToRow(Row row) private boolean maybeAbort() { /** - * The value returned by ApproximateTime.currentTimeMillis() is updated only every - * {@link ApproximateTime.CHECK_INTERVAL_MS}, by default 10 millis. Since MonitorableImpl - * relies on ApproximateTime, we don't need to check unless the approximate time has elapsed. + * TODO: this is not a great way to abort early; why not expressly limit checks to 10ms intervals? + * The value returned by approxTime.now() is updated only every + * {@link org.apache.cassandra.utils.MonotonicClock.SampledClock.CHECK_INTERVAL_MS}, by default 2 millis. Since MonitorableImpl + * relies on approxTime, we don't need to check unless the approximate time has elapsed. 
*/ - if (lastChecked == ApproximateTime.currentTimeMillis()) + if (lastChecked == approxTime.now()) return false; - lastChecked = ApproximateTime.currentTimeMillis(); + lastChecked = approxTime.now(); if (isAborted()) { @@ -661,7 +667,14 @@ protected UnfilteredPartitionIterator withStateTracking(UnfilteredPartitionItera /** * Creates a message for this command. */ - public abstract MessageOut createMessage(); + public Message createMessage(boolean trackRepairedData) + { + return trackRepairedData + ? Message.outWithFlags(verb(), this, MessageFlag.CALL_BACK_ON_FAILURE, MessageFlag.TRACK_REPAIRED_DATA) + : Message.outWithFlag (verb(), this, MessageFlag.CALL_BACK_ON_FAILURE); + } + + public abstract Verb verb(); protected abstract void appendCQLWhereClause(StringBuilder sb); diff --git a/src/java/org/apache/cassandra/db/ReadCommandVerbHandler.java b/src/java/org/apache/cassandra/db/ReadCommandVerbHandler.java index e39e8a855205..2c28ed9d4b8f 100644 --- a/src/java/org/apache/cassandra/db/ReadCommandVerbHandler.java +++ b/src/java/org/apache/cassandra/db/ReadCommandVerbHandler.java @@ -20,29 +20,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.dht.Token; import org.apache.cassandra.exceptions.InvalidRequestException; -import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.tracing.Tracing; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + public class ReadCommandVerbHandler implements IVerbHandler { - 
private static final Logger logger = LoggerFactory.getLogger(ReadCommandVerbHandler.class); + public static final ReadCommandVerbHandler instance = new ReadCommandVerbHandler(); - protected IVersionedSerializer serializer() - { - return ReadResponse.serializer; - } + private static final Logger logger = LoggerFactory.getLogger(ReadCommandVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { if (StorageService.instance.isBootstrapMode()) { @@ -52,9 +49,10 @@ public void doVerb(MessageIn message, int id) ReadCommand command = message.payload; validateTransientStatus(message); - command.setMonitoringTime(message.constructionTime, message.isCrossNode(), message.getTimeout(), message.getSlowQueryTimeout()); + long timeout = message.expiresAtNanos() - message.createdAtNanos(); + command.setMonitoringTime(message.createdAtNanos(), message.isCrossNode(), timeout, DatabaseDescriptor.getSlowQueryTimeout(NANOSECONDS)); - if (message.parameters.containsKey(ParameterType.TRACK_REPAIRED_DATA)) + if (message.trackRepairedData()) command.trackRepairedStatus(); ReadResponse response; @@ -66,17 +64,17 @@ public void doVerb(MessageIn message, int id) if (!command.complete()) { - Tracing.trace("Discarding partial response to {} (timed out)", message.from); - MessagingService.instance().incrementDroppedMessages(message, message.getLifetimeInMS()); + Tracing.trace("Discarding partial response to {} (timed out)", message.from()); + MessagingService.instance().metrics.recordDroppedMessage(message, message.elapsedSinceCreated(NANOSECONDS), NANOSECONDS); return; } - Tracing.trace("Enqueuing response to {}", message.from); - MessageOut reply = new MessageOut<>(MessagingService.Verb.REQUEST_RESPONSE, response, serializer()); - MessagingService.instance().sendReply(reply, id, message.from); + Tracing.trace("Enqueuing response to {}", message.from()); + Message reply = message.responseWith(response); + MessagingService.instance().send(reply, 
message.from()); } - private void validateTransientStatus(MessageIn message) + private void validateTransientStatus(Message message) { ReadCommand command = message.payload; Token token; @@ -93,14 +91,14 @@ private void validateTransientStatus(MessageIn message) if (replica == null) { logger.warn("Received a read request from {} for a range that is not owned by the current replica {}.", - message.from, + message.from(), command); return; } if (!command.acceptsTransient() && replica.isTransient()) { - MessagingService.instance().incrementDroppedMessages(message, message.getLifetimeInMS()); + MessagingService.instance().metrics.recordDroppedMessage(message, message.elapsedSinceCreated(NANOSECONDS), NANOSECONDS); throw new InvalidRequestException(String.format("Attempted to serve %s data request from %s node in %s", command.acceptsTransient() ? "transient" : "full", replica.isTransient() ? "transient" : "full", diff --git a/src/java/org/apache/cassandra/db/ReadExecutionController.java b/src/java/org/apache/cassandra/db/ReadExecutionController.java index 29b6fa7b0484..73ddad8022d8 100644 --- a/src/java/org/apache/cassandra/db/ReadExecutionController.java +++ b/src/java/org/apache/cassandra/db/ReadExecutionController.java @@ -21,9 +21,11 @@ import org.apache.cassandra.index.Index; import org.apache.cassandra.schema.TableMetadata; -import org.apache.cassandra.utils.Clock; +import org.apache.cassandra.utils.MonotonicClock; import org.apache.cassandra.utils.concurrent.OpOrder; +import static org.apache.cassandra.utils.MonotonicClock.preciseTime; + public class ReadExecutionController implements AutoCloseable { private static final long NO_SAMPLING = Long.MIN_VALUE; @@ -36,7 +38,7 @@ public class ReadExecutionController implements AutoCloseable private final ReadExecutionController indexController; private final WriteContext writeContext; private final ReadCommand command; - static Clock clock = Clock.instance; + static MonotonicClock clock = preciseTime; private final long 
createdAtNanos; // Only used while sampling @@ -93,7 +95,7 @@ static ReadExecutionController forCommand(ReadCommand command) ColumnFamilyStore baseCfs = Keyspace.openAndGetStore(command.metadata()); ColumnFamilyStore indexCfs = maybeGetIndexCfs(baseCfs, command); - long createdAtNanos = baseCfs.metric.topLocalReadQueryTime.isEnabled() ? clock.nanoTime() : NO_SAMPLING; + long createdAtNanos = baseCfs.metric.topLocalReadQueryTime.isEnabled() ? clock.now() : NO_SAMPLING; if (indexCfs == null) return new ReadExecutionController(command, baseCfs.readOrdering.start(), baseCfs.metadata(), null, null, createdAtNanos); @@ -172,7 +174,7 @@ public void close() private void addSample() { String cql = command.toCQLString(); - int timeMicros = (int) Math.min(TimeUnit.NANOSECONDS.toMicros(clock.nanoTime() - createdAtNanos), Integer.MAX_VALUE); + int timeMicros = (int) Math.min(TimeUnit.NANOSECONDS.toMicros(clock.now() - createdAtNanos), Integer.MAX_VALUE); ColumnFamilyStore cfs = ColumnFamilyStore.getIfExists(baseMetadata.id); if (cfs != null) cfs.metric.topLocalReadQueryTime.addSample(cql, timeMicros); diff --git a/src/java/org/apache/cassandra/db/ReadRepairVerbHandler.java b/src/java/org/apache/cassandra/db/ReadRepairVerbHandler.java index 2e499e7935c5..903b3d43bdf3 100644 --- a/src/java/org/apache/cassandra/db/ReadRepairVerbHandler.java +++ b/src/java/org/apache/cassandra/db/ReadRepairVerbHandler.java @@ -18,14 +18,16 @@ package org.apache.cassandra.db; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; public class ReadRepairVerbHandler implements IVerbHandler { - public void doVerb(MessageIn message, int id) + public static final ReadRepairVerbHandler instance = new ReadRepairVerbHandler(); + + public void doVerb(Message message) { message.payload.apply(); - MessagingService.instance().sendReply(WriteResponse.createMessage(), id, message.from); + 
MessagingService.instance().send(message.emptyResponse(), message.from()); } } diff --git a/src/java/org/apache/cassandra/db/SerializationHeader.java b/src/java/org/apache/cassandra/db/SerializationHeader.java index deadf68785b1..2e5211c087e8 100644 --- a/src/java/org/apache/cassandra/db/SerializationHeader.java +++ b/src/java/org/apache/cassandra/db/SerializationHeader.java @@ -27,6 +27,7 @@ import org.apache.cassandra.db.marshal.TypeParser; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.*; +import org.apache.cassandra.exceptions.UnknownColumnException; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.sstable.format.Version; import org.apache.cassandra.io.sstable.metadata.IMetadataComponentSerializer; @@ -292,7 +293,7 @@ public MetadataType getType() return MetadataType.HEADER; } - public SerializationHeader toHeader(TableMetadata metadata) + public SerializationHeader toHeader(TableMetadata metadata) throws UnknownColumnException { Map> typeMap = new HashMap<>(staticColumns.size() + regularColumns.size()); @@ -320,7 +321,7 @@ public SerializationHeader toHeader(TableMetadata metadata) // deserialization. The column will be ignore later on anyway. 
column = metadata.getDroppedColumn(name, isStatic); if (column == null) - throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); + throw new UnknownColumnException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); } builder.add(column); } diff --git a/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java b/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java index aec1a54cedb7..8c983aa164f8 100644 --- a/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java +++ b/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java @@ -20,10 +20,9 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.*; +import java.util.concurrent.TimeUnit; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import org.apache.cassandra.cache.IRowCacheEntry; @@ -43,12 +42,11 @@ import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.locator.ReplicaCollection; import org.apache.cassandra.metrics.TableMetrics; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.TableMetadata; @@ -364,9 +362,9 @@ public ClusteringIndexFilter clusteringIndexFilter(DecoratedKey key) return clusteringIndexFilter; } - public long getTimeout() + public long getTimeout(TimeUnit unit) { - return 
DatabaseDescriptor.getReadRpcTimeout(); + return DatabaseDescriptor.getReadRpcTimeout(unit); } @Override @@ -1040,9 +1038,10 @@ public String toString() nowInSec()); } - public MessageOut createMessage() + @Override + public Verb verb() { - return new MessageOut<>(MessagingService.Verb.READ, this, serializer); + return Verb.READ_REQ; } protected void appendCQLWhereClause(StringBuilder sb) @@ -1078,6 +1077,11 @@ public boolean isLimitedToOnePartition() return true; } + public boolean isRangeRequest() + { + return false; + } + /** * Groups multiple single partition read commands. */ diff --git a/src/java/org/apache/cassandra/db/SnapshotCommand.java b/src/java/org/apache/cassandra/db/SnapshotCommand.java index eb6f67a028e7..484db2fd7178 100644 --- a/src/java/org/apache/cassandra/db/SnapshotCommand.java +++ b/src/java/org/apache/cassandra/db/SnapshotCommand.java @@ -22,8 +22,8 @@ import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.Verb; public class SnapshotCommand { @@ -42,11 +42,6 @@ public SnapshotCommand(String keyspace, String columnFamily, String snapshotName this.clear_snapshot = clearSnapshot; } - public MessageOut createMessage() - { - return new MessageOut(MessagingService.Verb.SNAPSHOT, this, serializer); - } - @Override public String toString() { diff --git a/src/java/org/apache/cassandra/db/SystemKeyspace.java b/src/java/org/apache/cassandra/db/SystemKeyspace.java index d48f84fc3f0e..4c586f407452 100644 --- a/src/java/org/apache/cassandra/db/SystemKeyspace.java +++ b/src/java/org/apache/cassandra/db/SystemKeyspace.java @@ -688,16 +688,17 @@ public static synchronized void updateTokens(InetAddressAndPort ep, Collection serializer = new Serializer(); + + public final String keyspace; + 
public final String table; + + public TruncateRequest(String keyspace, String table) + { + this.keyspace = keyspace; + this.table = table; + } + + @Override + public String toString() + { + return String.format("TruncateRequest(keyspace='%s', table='%s')'", keyspace, table); + } + + private static class Serializer implements IVersionedSerializer + { + public void serialize(TruncateRequest request, DataOutputPlus out, int version) throws IOException + { + out.writeUTF(request.keyspace); + out.writeUTF(request.table); + } + + public TruncateRequest deserialize(DataInputPlus in, int version) throws IOException + { + String keyspace = in.readUTF(); + String table = in.readUTF(); + return new TruncateRequest(keyspace, table); + } + + public long serializedSize(TruncateRequest request, int version) + { + return TypeSizes.sizeof(request.keyspace) + TypeSizes.sizeof(request.table); + } + } +} diff --git a/src/java/org/apache/cassandra/db/TruncateResponse.java b/src/java/org/apache/cassandra/db/TruncateResponse.java index af4ed8f2f929..822c9ccea30f 100644 --- a/src/java/org/apache/cassandra/db/TruncateResponse.java +++ b/src/java/org/apache/cassandra/db/TruncateResponse.java @@ -22,8 +22,6 @@ import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; /** * This message is sent back the truncate operation and basically specifies if @@ -44,11 +42,6 @@ public TruncateResponse(String keyspace, String columnFamily, boolean success) this.success = success; } - public MessageOut createMessage() - { - return new MessageOut(MessagingService.Verb.REQUEST_RESPONSE, this, serializer); - } - public static class TruncateResponseSerializer implements IVersionedSerializer { public void serialize(TruncateResponse tr, DataOutputPlus out, int version) throws IOException diff --git 
a/src/java/org/apache/cassandra/db/TruncateVerbHandler.java b/src/java/org/apache/cassandra/db/TruncateVerbHandler.java index c2fac6561c87..c605d1f20e78 100644 --- a/src/java/org/apache/cassandra/db/TruncateVerbHandler.java +++ b/src/java/org/apache/cassandra/db/TruncateVerbHandler.java @@ -22,21 +22,23 @@ import org.apache.cassandra.io.FSError; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.tracing.Tracing; -public class TruncateVerbHandler implements IVerbHandler +public class TruncateVerbHandler implements IVerbHandler { + public static final TruncateVerbHandler instance = new TruncateVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(TruncateVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - Truncation t = message.payload; - Tracing.trace("Applying truncation of {}.{}", t.keyspace, t.columnFamily); + TruncateRequest t = message.payload; + Tracing.trace("Applying truncation of {}.{}", t.keyspace, t.table); try { - ColumnFamilyStore cfs = Keyspace.open(t.keyspace).getColumnFamilyStore(t.columnFamily); + ColumnFamilyStore cfs = Keyspace.open(t.keyspace).getColumnFamilyStore(t.table); cfs.truncateBlocking(); } catch (Exception e) @@ -47,16 +49,16 @@ public void doVerb(MessageIn message, int id) if (FSError.findNested(e) != null) throw FSError.findNested(e); } - Tracing.trace("Enqueuing response to truncate operation to {}", message.from); + Tracing.trace("Enqueuing response to truncate operation to {}", message.from()); - TruncateResponse response = new TruncateResponse(t.keyspace, t.columnFamily, true); - logger.trace("{} applied. 
Enqueuing response to {}@{} ", t, id, message.from ); - MessagingService.instance().sendReply(response.createMessage(), id, message.from); + TruncateResponse response = new TruncateResponse(t.keyspace, t.table, true); + logger.trace("{} applied. Enqueuing response to {}@{} ", t, message.id(), message.from()); + MessagingService.instance().send(message.responseWith(response), message.from()); } - private static void respondError(Truncation t, MessageIn truncateRequestMessage) + private static void respondError(TruncateRequest t, Message truncateRequestMessage) { - TruncateResponse response = new TruncateResponse(t.keyspace, t.columnFamily, false); - MessagingService.instance().sendOneWay(response.createMessage(), truncateRequestMessage.from); + TruncateResponse response = new TruncateResponse(t.keyspace, t.table, false); + MessagingService.instance().send(truncateRequestMessage.responseWith(response), truncateRequestMessage.from()); } } diff --git a/src/java/org/apache/cassandra/db/Truncation.java b/src/java/org/apache/cassandra/db/Truncation.java deleted file mode 100644 index 39a2ec6b98a4..000000000000 --- a/src/java/org/apache/cassandra/db/Truncation.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.db; - -import java.io.IOException; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; - -/** - * A truncate operation descriptor - */ -public class Truncation -{ - public static final IVersionedSerializer serializer = new TruncationSerializer(); - - public final String keyspace; - public final String columnFamily; - - public Truncation(String keyspace, String columnFamily) - { - this.keyspace = keyspace; - this.columnFamily = columnFamily; - } - - public MessageOut createMessage() - { - return new MessageOut(MessagingService.Verb.TRUNCATE, this, serializer); - } - - public String toString() - { - return "Truncation(" + "keyspace='" + keyspace + '\'' + ", cf='" + columnFamily + "\')"; - } -} - -class TruncationSerializer implements IVersionedSerializer -{ - public void serialize(Truncation t, DataOutputPlus out, int version) throws IOException - { - out.writeUTF(t.keyspace); - out.writeUTF(t.columnFamily); - } - - public Truncation deserialize(DataInputPlus in, int version) throws IOException - { - String keyspace = in.readUTF(); - String columnFamily = in.readUTF(); - return new Truncation(keyspace, columnFamily); - } - - public long serializedSize(Truncation truncation, int version) - { - return TypeSizes.sizeof(truncation.keyspace) + TypeSizes.sizeof(truncation.columnFamily); - } -} diff --git a/src/java/org/apache/cassandra/db/columniterator/SSTableReversedIterator.java b/src/java/org/apache/cassandra/db/columniterator/SSTableReversedIterator.java index 60a6f70534e4..1e1030cf3dfe 100644 --- a/src/java/org/apache/cassandra/db/columniterator/SSTableReversedIterator.java +++ 
b/src/java/org/apache/cassandra/db/columniterator/SSTableReversedIterator.java @@ -108,7 +108,7 @@ protected ReusablePartitionData createBuffer(int blocksCount) // FIXME: so far we only keep stats on cells, so to get a rough estimate on the number of rows, // we divide by the number of regular columns the table has. We should fix once we collect the // stats on rows - int estimatedRowsPerPartition = (int)(sstable.getEstimatedColumnCount().percentile(0.75) / columnCount); + int estimatedRowsPerPartition = (int)(sstable.getEstimatedCellPerPartitionCount().percentile(0.75) / columnCount); estimatedRowCount = Math.max(estimatedRowsPerPartition / blocksCount, 1); } catch (IllegalStateException e) diff --git a/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogService.java b/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogService.java index b7ab70592c06..45df4eb46733 100644 --- a/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogService.java +++ b/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogService.java @@ -30,7 +30,7 @@ import org.apache.cassandra.concurrent.NamedThreadFactory; import org.apache.cassandra.config.Config; import org.apache.cassandra.db.commitlog.CommitLogSegment.Allocation; -import org.apache.cassandra.utils.Clock; +import org.apache.cassandra.utils.MonotonicClock; import org.apache.cassandra.utils.NoSpamLogger; import org.apache.cassandra.utils.concurrent.WaitQueue; @@ -133,21 +133,21 @@ void start() throw new IllegalArgumentException(String.format("Commit log flush interval must be positive: %fms", syncIntervalNanos * 1e-6)); shutdown = false; - Runnable runnable = new SyncRunnable(new Clock()); + Runnable runnable = new SyncRunnable(MonotonicClock.preciseTime); thread = NamedThreadFactory.createThread(runnable, name); thread.start(); } class SyncRunnable implements Runnable { - private final Clock clock; + private final MonotonicClock clock; private long firstLagAt = 0; private long totalSyncDuration = 0; 
// total time spent syncing since firstLagAt private long syncExceededIntervalBy = 0; // time that syncs exceeded pollInterval since firstLagAt private int lagCount = 0; private int syncCount = 0; - SyncRunnable(Clock clock) + SyncRunnable(MonotonicClock clock) { this.clock = clock; } @@ -169,7 +169,7 @@ boolean sync() try { // sync and signal - long pollStarted = clock.nanoTime(); + long pollStarted = clock.now(); boolean flushToDisk = lastSyncedAt + syncIntervalNanos <= pollStarted || shutdownRequested || syncRequested; if (flushToDisk) { @@ -186,7 +186,7 @@ boolean sync() commitLog.sync(false); } - long now = clock.nanoTime(); + long now = clock.now(); if (flushToDisk) maybeLogFlushLag(pollStarted, now); diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLog.java b/src/java/org/apache/cassandra/db/commitlog/CommitLog.java index 9d2a3695ca85..1b3fe2e3ad97 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLog.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLog.java @@ -67,7 +67,7 @@ public class CommitLog implements CommitLogMBean // empty segments when writing large records final long MAX_MUTATION_SIZE = DatabaseDescriptor.getMaxMutationSize(); - final public AbstractCommitLogSegmentManager segmentManager; + final public CommitLogSegmentManager segmentManager; public final CommitLogArchiver archiver; final CommitLogMetrics metrics; @@ -108,9 +108,7 @@ private static CommitLog construct() throw new IllegalArgumentException("Unknown commitlog service type: " + DatabaseDescriptor.getCommitLogSync()); } - segmentManager = DatabaseDescriptor.isCDCEnabled() - ? 
new CommitLogSegmentManagerCDC(this, DatabaseDescriptor.getCommitLogLocation()) - : new CommitLogSegmentManagerStandard(this, DatabaseDescriptor.getCommitLogLocation()); + segmentManager = new CommitLogSegmentManager(this, DatabaseDescriptor.getCommitLogLocation()); // register metrics metrics.attach(executor, segmentManager); @@ -288,7 +286,7 @@ public CommitLogPosition add(Mutation mutation) throws CDCWriteException } catch (IOException e) { - throw new FSWriteError(e, segmentManager.allocatingFrom().getPath()); + throw new FSWriteError(e, segmentManager.getActiveSegment().getPath()); } } @@ -308,7 +306,7 @@ public void discardCompletedSegments(final TableId id, final CommitLogPosition l // flushed CF as clean, until we reach the segment file containing the CommitLogPosition passed // in the arguments. Any segments that become unused after they are marked clean will be // recycled or discarded. - for (Iterator iter = segmentManager.getActiveSegments().iterator(); iter.hasNext();) + for (Iterator iter = segmentManager.getSegmentsForUnflushedTables().iterator(); iter.hasNext();) { CommitLogSegment segment = iter.next(); segment.markClean(id, lowerBound, upperBound); @@ -364,7 +362,7 @@ public String getRestorePrecision() public List getActiveSegmentNames() { - Collection segments = segmentManager.getActiveSegments(); + Collection segments = segmentManager.getSegmentsForUnflushedTables(); List segmentNames = new ArrayList<>(segments.size()); for (CommitLogSegment seg : segments) segmentNames.add(seg.getName()); @@ -380,7 +378,7 @@ public List getArchivingSegmentNames() public long getActiveContentSize() { long size = 0; - for (CommitLogSegment seg : segmentManager.getActiveSegments()) + for (CommitLogSegment seg : segmentManager.getSegmentsForUnflushedTables()) size += seg.contentSize(); return size; } @@ -395,13 +393,14 @@ public long getActiveOnDiskSize() public Map getActiveSegmentCompressionRatios() { Map segmentRatios = new TreeMap<>(); - for (CommitLogSegment 
seg : segmentManager.getActiveSegments()) + for (CommitLogSegment seg : segmentManager.getSegmentsForUnflushedTables()) segmentRatios.put(seg.getName(), 1.0 * seg.onDiskSize() / seg.contentSize()); return segmentRatios; } /** * Shuts down the threads used by the commit log, blocking until completion. + * TODO this should accept a timeout, and throw TimeoutException */ public void shutdownBlocking() throws InterruptedException { diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogDescriptor.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogDescriptor.java index 700f12a242ca..2d04126ff3d5 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogDescriptor.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogDescriptor.java @@ -62,7 +62,7 @@ public class CommitLogDescriptor public static final int VERSION_40 = 7; /** - * Increment this number if there is a changes in the commit log disc layout or MessagingVersion changes. + * Increment this number if there is a change in the commit log disk layout or MessagingVersion changes.
* Note: make sure to handle {@link #getMessagingVersion()} */ @VisibleForTesting diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogReader.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogReader.java index 078bb5304d41..7fb63bb27724 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogReader.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogReader.java @@ -24,6 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.zip.CRC32; +import javax.annotation.Nonnull; + import com.google.common.annotations.VisibleForTesting; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -43,42 +45,96 @@ import org.apache.cassandra.io.util.RebufferingInputStream; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.utils.JVMStabilityInspector; +import org.apache.cassandra.utils.Pair; import static org.apache.cassandra.utils.FBUtilities.updateChecksumInt; +/** + * The CommitLogReader presents an idempotent interface for legacy CommitLogSegment reads. The logic to read and + parse a CommitLogSegment is housed here, however this depends upon {@link ResumableCommitLogReader} for any non-trivial + or resumable coordination of reads.
+ */ public class CommitLogReader { private static final Logger logger = LoggerFactory.getLogger(CommitLogReader.class); private static final int LEGACY_END_OF_SEGMENT_MARKER = 0; + /** Used to indicate we want to read to the end of a commit log segment during logic for resumable reading */ + static final int READ_TO_END_OF_FILE = Integer.MAX_VALUE; + @VisibleForTesting - public static final int ALL_MUTATIONS = -1; + static final int ALL_MUTATIONS = -1; private final CRC32 checksum; private final Map invalidMutations; + @Nonnull private byte[] buffer; - public CommitLogReader() + CommitLogReader() { checksum = new CRC32(); invalidMutations = new HashMap<>(); buffer = new byte[4096]; } - public Set> getInvalidMutations() + Set> getInvalidMutations() { return invalidMutations.entrySet(); } - /** - * Reads all passed in files with no minimum, no start, and no mutation limit. - */ + /** Reads all passed in files with no minimum, no start, and no mutation limit. */ public void readAllFiles(CommitLogReadHandler handler, File[] files) throws IOException { readAllFiles(handler, files, CommitLogPosition.NONE); } + /** Reads all passed in files with minPosition, no start, and no mutation limit. 
*/ + public void readAllFiles(CommitLogReadHandler handler, File[] files, CommitLogPosition minPosition) throws IOException + { + List filteredLogs = filterCommitLogFiles(files); + int i = 0; + for (File file: filteredLogs) + { + i++; + readCommitLogSegment(handler, file, minPosition, ALL_MUTATIONS, i == filteredLogs.size()); + } + } + + /** Read a CommitLogSegment fully, no restrictions */ + void readCommitLogSegment(CommitLogReadHandler handler, File file, boolean tolerateTruncation) throws IOException + { + readCommitLogSegment(handler, file, CommitLogPosition.NONE, ALL_MUTATIONS, tolerateTruncation); + } + + /** Read passed in file fully, up to mutationLimit count */ + @VisibleForTesting + void readCommitLogSegment(CommitLogReadHandler handler, File file, int mutationLimit, boolean tolerateTruncation) throws IOException + { + readCommitLogSegment(handler, file, CommitLogPosition.NONE, mutationLimit, tolerateTruncation); + } + + /** Read all mutations from passed in file from minPosition in the logical CommitLog */ + void readCommitLogSegment(CommitLogReadHandler handler, File file, CommitLogPosition minPosition, boolean tolerateTruncation) throws IOException + { + readCommitLogSegment(handler, file, minPosition, ALL_MUTATIONS, tolerateTruncation); + } + + void readCommitLogSegment(CommitLogReadHandler handler, + File file, + CommitLogPosition minPosition, + int mutationLimit, + boolean tolerateTruncation) throws IOException + { + // TODO: Consider removing the need to create a resumable reader here and instead simply build and use the needed components. + // This would require a change to internalRead since that assumes a ResumableReader for convenience. 
+ try(ResumableCommitLogReader resumableReader = new ResumableCommitLogReader(file, handler, minPosition, mutationLimit, tolerateTruncation)) + { + resumableReader.readToCompletion(); + } + } + + /** Confirms whether the passed in file is one we should read or skip based on whether it's empty and passes crc */ private static boolean shouldSkip(File file) throws IOException, ConfigurationException { try(RandomAccessReader reader = RandomAccessReader.open(file)) @@ -90,6 +146,7 @@ private static boolean shouldSkip(File file) throws IOException, ConfigurationEx } } + /** Filters list of passed in CommitLogSegments based on shouldSkip logic, specifically whether files are empty and pass crc. */ static List filterCommitLogFiles(File[] toFilter) { List filtered = new ArrayList<>(toFilter.length); @@ -117,163 +174,126 @@ static List filterCommitLogFiles(File[] toFilter) } /** - * Reads all passed in files with minPosition, no start, and no mutation limit. - */ - public void readAllFiles(CommitLogReadHandler handler, File[] files, CommitLogPosition minPosition) throws IOException - { - List filteredLogs = filterCommitLogFiles(files); - int i = 0; - for (File file: filteredLogs) - { - i++; - readCommitLogSegment(handler, file, minPosition, ALL_MUTATIONS, i == filteredLogs.size()); - } - } - - /** - * Reads passed in file fully - */ - public void readCommitLogSegment(CommitLogReadHandler handler, File file, boolean tolerateTruncation) throws IOException - { - readCommitLogSegment(handler, file, CommitLogPosition.NONE, ALL_MUTATIONS, tolerateTruncation); - } - - /** - * Reads all mutations from passed in file from minPosition - */ - public void readCommitLogSegment(CommitLogReadHandler handler, File file, CommitLogPosition minPosition, boolean tolerateTruncation) throws IOException - { - readCommitLogSegment(handler, file, minPosition, ALL_MUTATIONS, tolerateTruncation); - } - - /** - * Reads passed in file fully, up to mutationLimit count - */ - @VisibleForTesting - public 
void readCommitLogSegment(CommitLogReadHandler handler, File file, int mutationLimit, boolean tolerateTruncation) throws IOException - { - readCommitLogSegment(handler, file, CommitLogPosition.NONE, mutationLimit, tolerateTruncation); - } - - /** - * Reads mutations from file, handing them off to handler - * @param handler Handler that will take action based on deserialized Mutations - * @param file CommitLogSegment file to read - * @param minPosition Optional minimum CommitLogPosition - all segments with id larger or matching w/greater position will be read - * @param mutationLimit Optional limit on # of mutations to replay. Local ALL_MUTATIONS serves as marker to play all. - * @param tolerateTruncation Whether or not we should allow truncation of this file or throw if EOF found + * Reads and constructs the {@link CommitLogDescriptor} portion of a File. * - * @throws IOException + * @return Pair, Integer> An optional descriptor and serialized header size */ - public void readCommitLogSegment(CommitLogReadHandler handler, - File file, - CommitLogPosition minPosition, - int mutationLimit, - boolean tolerateTruncation) throws IOException + static Pair, Integer> readCommitLogDescriptor(CommitLogReadHandler handler, + File file, + boolean tolerateTruncation) throws IOException { // just transform from the file name (no reading of headers) to determine version CommitLogDescriptor desc = CommitLogDescriptor.fromFileName(file.getName()); + long segmentIdFromFilename = desc.id; + int descriptorSize = -1; - try(RandomAccessReader reader = RandomAccessReader.open(file)) + try(RandomAccessReader rawSegmentReader = RandomAccessReader.open(file)) { - final long segmentIdFromFilename = desc.id; try { // The following call can either throw or legitimately return null. For either case, we need to check // desc outside this block and set it to null in the exception case. 
- desc = CommitLogDescriptor.readHeader(reader, DatabaseDescriptor.getEncryptionContext()); + desc = CommitLogDescriptor.readHeader(rawSegmentReader, DatabaseDescriptor.getEncryptionContext()); } catch (Exception e) { desc = null; } + if (desc == null) { - // don't care about whether or not the handler thinks we can continue. We can't w/out descriptor. - // whether or not we can continue depends on whether this is the last segment - handler.handleUnrecoverableError(new CommitLogReadException( + // Don't care about whether or not the handler thinks we can continue. We can't w/out descriptor. + // Whether or not we can continue depends on whether this is the last segment + handler.handleUnrecoverableError(new CommitLogReadHandler.CommitLogReadException( String.format("Could not read commit log descriptor in file %s", file), - CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, + CommitLogReadHandler.CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, tolerateTruncation)); - return; + return Pair.create(Optional.empty(), -1); } + // Continuing if our file name and descriptor mismatch is optional. 
if (segmentIdFromFilename != desc.id) { - if (handler.shouldSkipSegmentOnError(new CommitLogReadException(String.format( - "Segment id mismatch (filename %d, descriptor %d) in file %s", segmentIdFromFilename, desc.id, file), - CommitLogReadErrorReason.RECOVERABLE_DESCRIPTOR_ERROR, - false))) + if (handler.shouldSkipSegmentOnError(new CommitLogReadHandler.CommitLogReadException( + String.format("Segment id mismatch (filename %d, descriptor %d) in file %s", segmentIdFromFilename, desc.id, file), + CommitLogReadHandler.CommitLogReadErrorReason.RECOVERABLE_DESCRIPTOR_ERROR, + false))) { - return; + return Pair.create(Optional.empty(), -1); } } - if (shouldSkipSegmentId(file, desc, minPosition)) - return; - - CommitLogSegmentReader segmentReader; - try - { - segmentReader = new CommitLogSegmentReader(handler, desc, reader, tolerateTruncation); - } - catch(Exception e) - { - handler.handleUnrecoverableError(new CommitLogReadException( - String.format("Unable to create segment reader for commit log file: %s", e), - CommitLogReadErrorReason.UNRECOVERABLE_UNKNOWN_ERROR, - tolerateTruncation)); - return; - } - - try - { - ReadStatusTracker statusTracker = new ReadStatusTracker(mutationLimit, tolerateTruncation); - for (CommitLogSegmentReader.SyncSegment syncSegment : segmentReader) - { - // Only tolerate truncation if we allow in both global and segment - statusTracker.tolerateErrorsInSection = tolerateTruncation & syncSegment.toleratesErrorsInSection; - - // Skip segments that are completely behind the desired minPosition - if (desc.id == minPosition.segmentId && syncSegment.endPosition < minPosition.position) - continue; - - statusTracker.errorContext = String.format("Next section at %d in %s", syncSegment.fileStartPosition, desc.fileName()); + descriptorSize = (int)rawSegmentReader.getPosition(); + } + return Pair.create(Optional.of(desc), descriptorSize); + } - readSection(handler, syncSegment.input, minPosition, syncSegment.endPosition, statusTracker, desc); - if 
(!statusTracker.shouldContinue()) - break; - } - } - // Unfortunately AbstractIterator cannot throw a checked exception, so we check to see if a RuntimeException - // is wrapping an IOException. - catch (RuntimeException re) - { - if (re.getCause() instanceof IOException) - throw (IOException) re.getCause(); - throw re; - } - logger.debug("Finished reading {}", file); + /** + * Opens a RandomAccessReader to a CommitLogSegment _and does not close it_. Closed out in {@link ResumableCommitLogReader#close} + */ + static CommitLogSegmentReader getCommitLogSegmentReader(ResumableCommitLogReader parent) throws IOException + { + CommitLogSegmentReader result; + try + { + result = new CommitLogSegmentReader(parent); } + catch(Exception e) + { + parent.readHandler.handleUnrecoverableError(new CommitLogReadHandler.CommitLogReadException( + String.format("Unable to create segment reader for commit log file: %s", e), + CommitLogReadHandler.CommitLogReadErrorReason.UNRECOVERABLE_UNKNOWN_ERROR, + parent.tolerateTruncation)); + // Regardless of whether this is in the node context and we allow the node to continue to run, this reader is + // dead. + parent.close(); + return null; + } + return result; } /** - * Any segment with id >= minPosition.segmentId is a candidate for read. + * Iterates over {@link CommitLogSegmentReader.SyncSegment} until it hits offset limit or end of iterator, based + * on the resumable reader's sentinel. 
*/ - private boolean shouldSkipSegmentId(File file, CommitLogDescriptor desc, CommitLogPosition minPosition) + void internalReadCommitLogSegment(ResumableCommitLogReader rr) throws IOException { - logger.debug("Reading {} (CL version {}, messaging version {}, compression {})", - file.getPath(), - desc.version, - desc.getMessagingVersion(), - desc.compression); + try + { + ReadStatusTracker statusTracker = new ReadStatusTracker(rr.mutationLimit, rr.tolerateTruncation); + + int lastSegmentEnd = -1; + + while (lastSegmentEnd < rr.offsetLimit && rr.activeIterator.hasNext()) + { + CommitLogSegmentReader.SyncSegment syncSegment = rr.activeIterator.next(); + // Back out if we're at the end of our current partially written CL segment. + if (syncSegment == CommitLogSegmentReader.RESUMABLE_SENTINEL) + break; + + lastSegmentEnd = syncSegment.endPosition; + + statusTracker.tolerateErrorsInSection = rr.tolerateTruncation & syncSegment.toleratesErrorsInSection; - if (minPosition.segmentId > desc.id) + // Skip segments that are completely behind the desired minPosition + if (rr.descriptor.id == rr.minPosition.segmentId && syncSegment.endPosition < rr.minPosition.position) + continue; + + statusTracker.errorContext = String.format("Next section at %d in %s", syncSegment.fileStartPosition, rr.descriptor.fileName()); + + readSection(rr.readHandler, syncSegment.input, rr.minPosition, syncSegment.endPosition, statusTracker, rr.descriptor); + if (!statusTracker.shouldContinue()) + break; + } + } + // Unfortunately AbstractIterator cannot throw a checked exception, so we check to see if a RuntimeException + // is wrapping an IOException. 
+ catch (RuntimeException | IOException re) { - logger.trace("Skipping read of fully-flushed {}", file); - return true; + if (re.getCause() instanceof IOException) + throw (IOException) re.getCause(); + throw re; } - return false; } /** @@ -403,7 +423,7 @@ private void readSection(CommitLogReadHandler handler, } /** - * Deserializes and passes a Mutation to the ICommitLogReadHandler requested + * Deserializes and passes a Mutation to the CommitLogReadHandler requested * * @param handler Handler that will take action based on deserialized Mutations * @param inputBuffer raw byte array w/Mutation data @@ -482,48 +502,51 @@ protected void readMutation(CommitLogReadHandler handler, */ private static class CommitLogFormat { - public static long calculateClaimedChecksum(FileDataInput input, int commitLogVersion) throws IOException + static long calculateClaimedChecksum(FileDataInput input, int commitLogVersion) throws IOException { return input.readInt() & 0xffffffffL; } - public static void updateChecksum(CRC32 checksum, int serializedSize, int commitLogVersion) + static void updateChecksum(CRC32 checksum, int serializedSize, int commitLogVersion) { updateChecksumInt(checksum, serializedSize); } - public static long calculateClaimedCRC32(FileDataInput input, int commitLogVersion) throws IOException + static long calculateClaimedCRC32(FileDataInput input, int commitLogVersion) throws IOException { return input.readInt() & 0xffffffffL; } } + /** + * Caches the state needed for decision-making on multiple CommitLog Read operations. 
Used internally in the CommitLogReader + */ private static class ReadStatusTracker { private int mutationsLeft; - public String errorContext = ""; - public boolean tolerateErrorsInSection; + String errorContext = ""; + boolean tolerateErrorsInSection; private boolean error; - public ReadStatusTracker(int mutationLimit, boolean tolerateErrorsInSection) + ReadStatusTracker(int mutationLimit, boolean tolerateErrorsInSection) { this.mutationsLeft = mutationLimit; this.tolerateErrorsInSection = tolerateErrorsInSection; } - public void addProcessedMutation() + void addProcessedMutation() { if (mutationsLeft == ALL_MUTATIONS) return; --mutationsLeft; } - public boolean shouldContinue() + boolean shouldContinue() { return !error && (mutationsLeft != 0 || mutationsLeft == ALL_MUTATIONS); } - public void requestTermination() + void requestTermination() { error = true; } diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogReplayer.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogReplayer.java index 2947222e4d93..f63499ab23c0 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogReplayer.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogReplayer.java @@ -167,8 +167,8 @@ public void replayFiles(File[] clogs) throws IOException */ private void handleCDCReplayCompletion(File f) throws IOException { - // Can only reach this point if CDC is enabled, thus we have a CDCSegmentManager - ((CommitLogSegmentManagerCDC)CommitLog.instance.segmentManager).addCDCSize(f.length()); + // Can only reach this point if CDC is enabled, thus we have a CDC Allocator + ((CommitLogSegmentAllocatorCDC)CommitLog.instance.segmentManager.segmentAllocator).addCDCSize(f.length()); File dest = new File(DatabaseDescriptor.getCDCLogLocation(), f.getName()); @@ -217,7 +217,7 @@ public int blockForWrites() // also flush batchlog incase of any MV updates if (!flushingSystem) - 
futures.add(Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceFlush()); + futures.add(Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceFlushToSSTable()); FBUtilities.waitOnFutures(futures); diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegment.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegment.java index 5303de9da40e..793605b860a5 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegment.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegment.java @@ -32,6 +32,9 @@ import com.google.common.annotations.VisibleForTesting; import org.cliffc.high_scale_lib.NonBlockingHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.codahale.metrics.Timer; import org.apache.cassandra.config.*; import org.apache.cassandra.db.Mutation; @@ -49,13 +52,15 @@ import static org.apache.cassandra.utils.FBUtilities.updateChecksumInt; -/* +/** * A single commit log file on disk. Manages creation of the file and writing mutations to disk, * as well as tracking the last mutation position of any "dirty" CFs covered by the segment file. Segment * files are initially allocated to a fixed size and can grow to accomidate a larger value if necessary. 
*/ public abstract class CommitLogSegment { + static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentManager.class); + private final static long idBase; private CDCState cdcState = CDCState.PERMITTED; @@ -65,7 +70,7 @@ public enum CDCState FORBIDDEN, CONTAINS } - Object cdcStateLock = new Object(); + final Object cdcStateLock = new Object(); private final static AtomicInteger nextId = new AtomicInteger(1); private static long replayLimitId; @@ -80,20 +85,21 @@ public enum CDCState replayLimitId = idBase = Math.max(System.currentTimeMillis(), maxId + 1); } - // The commit log entry overhead in bytes (int: length + int: head checksum + int: tail checksum) - public static final int ENTRY_OVERHEAD_SIZE = 4 + 4 + 4; + /** The commit log entry overhead in bytes (int: length + int: head checksum + int: tail checksum) */ + static final int ENTRY_OVERHEAD_SIZE = 4 + 4 + 4; - // The commit log (chained) sync marker/header size in bytes (int: length + int: checksum [segmentId, position]) + /** The commit log (chained) sync marker/header size in bytes (int: length + int: checksum [segmentId, position]) */ static final int SYNC_MARKER_SIZE = 4 + 4; - // The OpOrder used to order appends wrt sync + /** The OpOrder used to order appends wrt sync */ private final OpOrder appendOrder = new OpOrder(); private final AtomicInteger allocatePosition = new AtomicInteger(); - // Everything before this offset has been synced and written. The SYNC_MARKER_SIZE bytes after - // each sync are reserved, and point forwards to the next such offset. The final - // sync marker in a segment will be zeroed out, or point to a position too close to the EOF to fit a marker. + /** Everything before this offset has been synced and written. The SYNC_MARKER_SIZE bytes after + * each sync are reserved, and point forwards to the next such offset. The final + * sync marker in a segment will be zeroed out, or point to a position too close to the EOF to fit a marker. 
+ */ @VisibleForTesting volatile int lastSyncedOffset; @@ -103,34 +109,36 @@ public enum CDCState */ private volatile int lastMarkerOffset; - // The end position of the buffer. Initially set to its capacity and updated to point to the last written position - // as the segment is being closed. - // No need to be volatile as writes are protected by appendOrder barrier. + /** The end position of the buffer. Initially set to its capacity and updated to point to the last written position + * as the segment is being closed. + * No need to be volatile as writes are protected by appendOrder barrier. + */ private int endOfBuffer; - // a signal for writers to wait on to confirm the log message they provided has been written to disk + /** a signal for writers to wait on to confirm the log message they provided has been written to disk */ private final WaitQueue syncComplete = new WaitQueue(); - // a map of Cf->dirty interval in this segment; if interval is not covered by the clean set, the log contains unflushed data + /** a map of Cf->dirty interval in this segment; if interval is not covered by the clean set, the log contains unflushed data */ private final NonBlockingHashMap tableDirty = new NonBlockingHashMap<>(1024); - // a map of Cf->clean intervals; separate map from above to permit marking Cfs clean whilst the log is still in use + /** a map of Cf->clean intervals; separate map from above to permit marking Cfs clean whilst the log is still in use */ private final ConcurrentHashMap tableClean = new ConcurrentHashMap<>(); public final long id; + /** The CommitLogSegment log file on disk */ final File logFile; final FileChannel channel; final int fd; - protected final AbstractCommitLogSegmentManager manager; + protected final CommitLogSegmentManager manager; ByteBuffer buffer; private volatile boolean headerWritten; public final CommitLogDescriptor descriptor; - static CommitLogSegment createSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + static 
CommitLogSegment createSegment(CommitLog commitLog, CommitLogSegmentManager manager) { Configuration config = commitLog.configuration; CommitLogSegment segment = config.useEncryption() ? new EncryptedSegment(commitLog, manager) @@ -160,7 +168,7 @@ static long getNextId() /** * Constructs a new segment file. */ - CommitLogSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + CommitLogSegment(CommitLog commitLog, CommitLogSegmentManager manager) { this.manager = manager; @@ -265,10 +273,10 @@ private int allocate(int size) } } - // ensures no more of this segment is writeable, by allocating any unused section at the end and marking it discarded + // Ensures no more of this segment is writeable, by allocating any unused section at the end and marking it discarded. void discardUnusedTail() { - // We guard this with the OpOrdering instead of synchronised due to potential dead-lock with ACLSM.advanceAllocatingFrom() + // We guard this with the OpOrdering instead of synchronised due to potential dead-lock with CLSM.switchToNewSegment() // Ensures endOfBuffer update is reflected in the buffer end position picked up by sync(). // This actually isn't strictly necessary, as currently all calls to discardUnusedTail are executed either by the thread // running sync or within a mutation already protected by this OpOrdering, but to prevent future potential mistakes, @@ -310,6 +318,8 @@ void waitForModifications() * Update the chained markers in the commit log buffer and possibly force a disk flush for this segment file. * * @param flush true if the segment should flush to disk; else, false for just updating the chained markers. + * Named such to disambiguate whether we're looking to flush associated memtables to disk + * or just this CL segment. 
*/ synchronized void sync(boolean flush) { @@ -363,7 +373,6 @@ synchronized void sync(boolean flush) sectionEnd = nextMarker; } - if (flush || close) { flush(startMarker, sectionEnd); @@ -383,7 +392,7 @@ synchronized void sync(boolean flush) * in shared / memory-mapped buffers reflects un-synced data so we need an external sentinel for clients to read to * determine actual durable data persisted. */ - public static void writeCDCIndexFile(CommitLogDescriptor desc, int offset, boolean complete) + static void writeCDCIndexFile(CommitLogDescriptor desc, int offset, boolean complete) { try(FileWriter writer = new FileWriter(new File(DatabaseDescriptor.getCDCLogLocation(), desc.cdcIndexFileName()))) { @@ -425,7 +434,7 @@ protected static void writeSyncMarker(long id, ByteBuffer buffer, int offset, in abstract void flush(int startMarker, int nextMarker); - public boolean isStillAllocating() + boolean hasRoom() { return allocatePosition.get() < endOfBuffer; } @@ -445,7 +454,7 @@ void discard(boolean deleteFile) /** * @return the current CommitLogPosition for this log segment */ - public CommitLogPosition getCurrentCommitLogPosition() + CommitLogPosition getCurrentCommitLogPosition() { return new CommitLogPosition(id, allocatePosition.get()); } @@ -469,7 +478,7 @@ public String getName() /** * @return a File object representing the CDC directory and this file name for hard-linking */ - public File getCDCFile() + File getCDCFile() { return new File(DatabaseDescriptor.getCDCLogLocation(), logFile.getName()); } @@ -477,7 +486,7 @@ public File getCDCFile() /** * @return a File object representing the CDC Index file holding the offset and completion status of this segment */ - public File getCDCIndexFile() + File getCDCIndexFile() { return new File(DatabaseDescriptor.getCDCLogLocation(), descriptor.cdcIndexFileName()); } @@ -539,7 +548,7 @@ protected void internalClose() } } - public static void coverInMap(ConcurrentMap map, K key, int value) + private static void 
coverInMap(ConcurrentMap map, K key, int value) { IntegerInterval i = map.get(key); if (i == null) @@ -561,7 +570,7 @@ public static void coverInMap(ConcurrentMap map, K key, i * @param startPosition the start of the range that is clean * @param endPosition the end of the range that is clean */ - public synchronized void markClean(TableId tableId, CommitLogPosition startPosition, CommitLogPosition endPosition) + synchronized void markClean(TableId tableId, CommitLogPosition startPosition, CommitLogPosition endPosition) { if (startPosition.segmentId > id || endPosition.segmentId < id) return; @@ -576,7 +585,7 @@ public synchronized void markClean(TableId tableId, CommitLogPosition startPosit private void removeCleanFromDirty() { // if we're still allocating from this segment, don't touch anything since it can't be done thread-safely - if (isStillAllocating()) + if (hasRoom()) return; Iterator> iter = tableClean.entrySet().iterator(); @@ -619,9 +628,9 @@ public synchronized Collection getDirtyTableIds() */ public synchronized boolean isUnused() { - // if room to allocate, we're still in use as the active allocatingFrom, + // if room to allocate, we're still in use as the active segment, // so we don't want to race with updates to tableClean with removeCleanFromDirty - if (isStillAllocating()) + if (hasRoom()) return false; removeCleanFromDirty(); @@ -640,7 +649,7 @@ public boolean contains(CommitLogPosition context) } // For debugging, not fast - public String dirtyString() + String dirtyString() { StringBuilder sb = new StringBuilder(); for (TableId tableId : getDirtyTableIds()) @@ -656,7 +665,7 @@ public String dirtyString() abstract public long onDiskSize(); - public long contentSize() + long contentSize() { return lastSyncedOffset; } @@ -677,7 +686,7 @@ public int compare(File f, File f2) } } - public CDCState getCDCState() + CDCState getCDCState() { return cdcState; } @@ -686,7 +695,7 @@ public CDCState getCDCState() * Change the current cdcState on this 
CommitLogSegment. There are some restrictions on state transitions and this * method is idempotent. */ - public void setCDCState(CDCState newState) + void setCDCState(CDCState newState) { if (newState == cdcState) return; @@ -748,7 +757,7 @@ void awaitDiskSync(Timer waitingOnCommit) /** * Returns the position in the CommitLogSegment at the end of this allocation. */ - public CommitLogPosition getCommitLogPosition() + CommitLogPosition getCommitLogPosition() { return new CommitLogPosition(segment.id, buffer.limit()); } diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocator.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocator.java new file mode 100644 index 000000000000..883b825be13e --- /dev/null +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocator.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.commitlog; + +import java.io.File; + +import org.apache.cassandra.db.Mutation; + +public interface CommitLogSegmentAllocator +{ + void start(); + void shutdown(); + + /** + * Indicates that a segment file has been flushed and is no longer needed. 
This can perform blocking disk + * operations so use with caution in critical path. + * + * @param segment segment to be discarded + * @param delete whether or not the segment is safe to be deleted. + */ + void discard(CommitLogSegment segment, boolean delete); + + /** + * Allocate a segment. This is always expected to succeed so should throw some form of exception on failure to + * allocate; if you can't allocate a CLS, you can no longer write and the node is in a bad state. + */ + CommitLogSegment.Allocation allocate(Mutation mutation, int size); + + /** + * Hook to allow segment managers to track state surrounding creation of new segments. This method is called + * on a separate segment management thread instead of the critical path so longer-running operations are acceptable. + */ + CommitLogSegment createSegment(); + + /** + * When segments complete replay, the allocator has a hook to take action at that time. + */ + void handleReplayedSegment(final File file); +} diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDC.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDC.java similarity index 72% rename from src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDC.java rename to src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDC.java index bdd4f74ae1e6..bd96a23d47b2 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDC.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDC.java @@ -39,28 +39,39 @@ import org.apache.cassandra.utils.DirectorySizeCalculator; import org.apache.cassandra.utils.NoSpamLogger; -public class CommitLogSegmentManagerCDC extends AbstractCommitLogSegmentManager +/** + * A CommitLogSegmentAllocator that respects the configured total allowable CDC space on disk. 
On allocation of a mutation + * checks if it's on a table tracked by CDC and, if so, either throws an exception if at CDC limit or flags that segment + * as containing a CDC mutation if it's a new one. + * + * This code path is only exercised if cdc is enabled on a node. We pay the duplication cost of having both CDC and non + * allocators in order to keep the old allocator code clean and separate from this allocator, as well as to not introduce + * unnecessary operations on the critical path for nodes / users where they have no interest in CDC. May be worth considering + * unifying in the future should the perf implications of this be shown to be negligible, though the hard linking and + * size tracking is somewhat distasteful to have floating around on nodes where cdc is not in use (which we assume to be + * the majority). + */ +public class CommitLogSegmentAllocatorCDC implements CommitLogSegmentAllocator { - static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentManagerCDC.class); + static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentAllocatorCDC.class); private final CDCSizeTracker cdcSizeTracker; + private final CommitLogSegmentManager segmentManager; - public CommitLogSegmentManagerCDC(final CommitLog commitLog, String storageDirectory) + CommitLogSegmentAllocatorCDC(CommitLogSegmentManager segmentManager) { - super(commitLog, storageDirectory); - cdcSizeTracker = new CDCSizeTracker(this, new File(DatabaseDescriptor.getCDCLogLocation())); + this.segmentManager = segmentManager; + cdcSizeTracker = new CDCSizeTracker(segmentManager, new File(DatabaseDescriptor.getCDCLogLocation())); } - @Override - void start() + public void start() { cdcSizeTracker.start(); - super.start(); } public void discard(CommitLogSegment segment, boolean delete) { segment.close(); - addSize(-segment.onDiskSize()); + segmentManager.addSize(-segment.onDiskSize()); cdcSizeTracker.processDiscardedSegment(segment); @@ -82,12 +93,11 @@ public void 
discard(CommitLogSegment segment, boolean delete) } /** - * Initiates the shutdown process for the management thread. Also stops the cdc on-disk size calculator executor. + * Stops the thread pool for CDC on disk size tracking. */ public void shutdown() { cdcSizeTracker.shutdown(); - super.shutdown(); } /** @@ -99,20 +109,23 @@ public void shutdown() * @return the created Allocation object * @throws CDCWriteException If segment disallows CDC mutations, we throw */ - @Override public CommitLogSegment.Allocation allocate(Mutation mutation, int size) throws CDCWriteException { - CommitLogSegment segment = allocatingFrom(); - CommitLogSegment.Allocation alloc; - + CommitLogSegment segment = segmentManager.getActiveSegment(); throwIfForbidden(mutation, segment); - while ( null == (alloc = segment.allocate(mutation, size)) ) + + CommitLogSegment.Allocation alloc = segment.allocate(mutation, size); + // If we failed to allocate in the segment, prompt for a switch to a new segment and loop on re-attempt. This + // is expected to succeed or throw, since CommitLog allocation working is central to how a node operates. + while (alloc == null) { // Failed to allocate, so move to a new segment with enough room if possible. - advanceAllocatingFrom(segment); - segment = allocatingFrom(); + segmentManager.switchToNewSegment(segment); + segment = segmentManager.getActiveSegment(); + // New segment, so confirm whether or not CDC mutations are allowed on this. 
throwIfForbidden(mutation, segment); + alloc = segment.allocate(mutation, size); } if (mutation.trackedByCDC()) @@ -143,7 +156,7 @@ private void throwIfForbidden(Mutation mutation, CommitLogSegment segment) throw */ public CommitLogSegment createSegment() { - CommitLogSegment segment = CommitLogSegment.createSegment(commitLog, this); + CommitLogSegment segment = CommitLogSegment.createSegment(segmentManager.commitLog, segmentManager); // Hard link file in cdc folder for realtime tracking FileUtils.createHardLink(segment.logFile, segment.getCDCFile()); @@ -157,11 +170,8 @@ public CommitLogSegment createSegment() * * @param file segment file that is no longer in use. */ - @Override - void handleReplayedSegment(final File file) + public void handleReplayedSegment(final File file) { - super.handleReplayedSegment(file); - // delete untracked cdc segment hard link files if their index files do not exist File cdcFile = new File(DatabaseDescriptor.getCDCLogLocation(), file.getName()); File cdcIndexFile = new File(DatabaseDescriptor.getCDCLogLocation(), CommitLogDescriptor.fromFileName(file.getName()).cdcIndexFileName()); @@ -175,7 +185,7 @@ void handleReplayedSegment(final File file) /** * For use after replay when replayer hard-links / adds tracking of replayed segments */ - public void addCDCSize(long size) + void addCDCSize(long size) { cdcSizeTracker.addSize(size); } @@ -185,18 +195,18 @@ public void addCDCSize(long size) * data in them and all segments archived into cdc_raw. * * Allows atomic increment/decrement of unflushed size, however only allows increment on flushed and requires a full - * directory walk to determine any potential deletions by CDC consumer. + * directory walk to determine any potential deletions by an external CDC consumer. 
*/ private static class CDCSizeTracker extends DirectorySizeCalculator { private final RateLimiter rateLimiter = RateLimiter.create(1000.0 / DatabaseDescriptor.getCDCDiskCheckInterval()); private ExecutorService cdcSizeCalculationExecutor; - private CommitLogSegmentManagerCDC segmentManager; + private final CommitLogSegmentManager segmentManager; - // Used instead of size during walk to remove chance of over-allocation + /** Used only in context of file tree walking thread; not read nor mutated outside this context */ private volatile long sizeInProgress = 0; - CDCSizeTracker(CommitLogSegmentManagerCDC segmentManager, File path) + CDCSizeTracker(CommitLogSegmentManager segmentManager, File path) { super(path); this.segmentManager = segmentManager; @@ -215,9 +225,9 @@ public void start() * Synchronous size recalculation on each segment creation/deletion call could lead to very long delays in new * segment allocation, thus long delays in thread signaling to wake waiting allocation / writer threads. * - * This can be reached either from the segment management thread in ABstractCommitLogSegmentManager or from the + * This can be reached either from the segment management thread in CommitLogSegmentManager or from the * size recalculation executor, so we synchronize on this object to reduce the race overlap window available for - * size to get off. + * size to drift. * * Reference DirectorySizerBench for more information about performance of the directory size recalc. */ @@ -237,6 +247,10 @@ void processNewSegment(CommitLogSegment segment) submitOverflowSizeRecalculation(); } + /** + * Upon segment discard, we need to adjust our known CDC consumption on disk based on whether or not this segment + * was flagged to be allowable for CDC. 
+ */ void processDiscardedSegment(CommitLogSegment segment) { // See synchronization in CommitLogSegment.setCDCState @@ -258,7 +272,13 @@ private long allowableCDCBytes() return (long)DatabaseDescriptor.getCDCSpaceInMB() * 1024 * 1024; } - public void submitOverflowSizeRecalculation() + /** + * The overflow size calculation requires walking the file tree and checking file size for all linked CDC + * files. As such, we do this async on the executor in the CDCSizeTracker instead of the context of the calling + * thread. While this can obviously introduce some delay / raciness in the calculation of CDC size consumed, + * the alternative of significantly long blocks for critical path CL allocation is unacceptable. + */ + void submitOverflowSizeRecalculation() { try { @@ -274,9 +294,12 @@ private void recalculateOverflowSize() { rateLimiter.acquire(); calculateSize(); - CommitLogSegment allocatingFrom = segmentManager.allocatingFrom(); - if (allocatingFrom.getCDCState() == CDCState.FORBIDDEN) - processNewSegment(allocatingFrom); + CommitLogSegment activeCommitLogSegment = segmentManager.getActiveSegment(); + // In the event that the current segment is disallowed for CDC, re-check it as our size on disk may have + // reduced, thus allowing the segment to accept CDC writes. It's worth noting: this would spin on recalc + // endlessly if not for the rate limiter dropping looping calls on the floor. 
+ if (activeCommitLogSegment.getCDCState() == CDCState.FORBIDDEN) + processNewSegment(activeCommitLogSegment); } private int defaultSegmentSize() @@ -288,14 +311,14 @@ private void calculateSize() { try { - // The Arrays.stream approach is considerably slower on Windows than linux + // The Arrays.stream approach is considerably slower sizeInProgress = 0; Files.walkFileTree(path.toPath(), this); size = sizeInProgress; } catch (IOException ie) { - CommitLog.instance.handleCommitError("Failed CDC Size Calculation", ie); + CommitLog.handleCommitError("Failed CDC Size Calculation", ie); } } @@ -327,7 +350,7 @@ private long totalCDCSizeOnDisk() * Only use for testing / validation that size tracker is working. Not for production use. */ @VisibleForTesting - public long updateCDCTotalSize() + long updateCDCTotalSize() { cdcSizeTracker.submitOverflowSizeRecalculation(); @@ -336,7 +359,9 @@ public long updateCDCTotalSize() { Thread.sleep(DatabaseDescriptor.getCDCDiskCheckInterval() + 10); } - catch (InterruptedException e) {} + catch (InterruptedException e) { + // Expected in test context. no-op. + } return cdcSizeTracker.totalCDCSizeOnDisk(); } diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorStandard.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorStandard.java new file mode 100644 index 000000000000..f3d64bfbe922 --- /dev/null +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorStandard.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.commitlog; + +import java.io.File; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.db.Mutation; +import org.apache.cassandra.io.util.FileUtils; + +/** + * This is a fairly simple form of a CommitLogSegmentAllocator. + */ +public class CommitLogSegmentAllocatorStandard implements CommitLogSegmentAllocator +{ + static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentAllocatorStandard.class); + private final CommitLogSegmentManager segmentManager; + + public void start() {} + public void shutdown() {} + + CommitLogSegmentAllocatorStandard(CommitLogSegmentManager segmentManager) { + this.segmentManager = segmentManager; + } + + /** + * No extra processing required beyond deletion of the file once we have replayed it. + */ + public void handleReplayedSegment(final File file) { + // (don't decrease managed size, since this was never a "live" segment) + logger.trace("(Unopened) segment {} is no longer needed and will be deleted now", file); + FileUtils.deleteWithConfirm(file); + } + + public void discard(CommitLogSegment segment, boolean delete) + { + segment.close(); + if (delete) + FileUtils.deleteWithConfirm(segment.logFile); + segmentManager.addSize(-segment.onDiskSize()); + } + + /** + * Reserve space in the current segment for the provided mutation or, if there isn't space available, + * create a new segment. 
allocate() is blocking until allocation succeeds as it waits on a signal in switchToNewSegment + * + * @param mutation mutation to allocate space for + * @param size total size of mutation (overhead + serialized size) + * @return the provided Allocation object + */ + public CommitLogSegment.Allocation allocate(Mutation mutation, int size) + { + CommitLogSegment segment = segmentManager.getActiveSegment(); + + CommitLogSegment.Allocation alloc = segment.allocate(mutation, size); + // If we failed to allocate in the segment, prompt for a switch to a new segment and loop on re-attempt. This + // is expected to succeed or throw, since CommitLog allocation working is central to how a node operates. + while (alloc == null) + { + // Failed to allocate, so move to a new segment with enough room if possible. + segmentManager.switchToNewSegment(segment); + segment = segmentManager.getActiveSegment(); + alloc = segment.allocate(mutation, size); + } + + return alloc; + } + + public CommitLogSegment createSegment() + { + return CommitLogSegment.createSegment(segmentManager.commitLog, segmentManager); + } +} diff --git a/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogSegmentManager.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManager.java similarity index 74% rename from src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogSegmentManager.java rename to src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManager.java index fdbf7f661ff3..2c3878ac1917 100755 --- a/src/java/org/apache/cassandra/db/commitlog/AbstractCommitLogSegmentManager.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManager.java @@ -32,45 +32,41 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.concurrent.NamedThreadFactory; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.io.util.FileUtils; -import org.apache.cassandra.schema.Schema; import org.apache.cassandra.db.*; 
+import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.schema.TableMetadata; -import org.apache.cassandra.utils.*; +import org.apache.cassandra.utils.JVMStabilityInspector; +import org.apache.cassandra.utils.WrappedRunnable; import org.apache.cassandra.utils.concurrent.WaitQueue; -import static org.apache.cassandra.db.commitlog.CommitLogSegment.Allocation; - /** * Performs eager-creation of commit log segments in a background thread. All the * public methods are thread safe. */ -public abstract class AbstractCommitLogSegmentManager +public class CommitLogSegmentManager { - static final Logger logger = LoggerFactory.getLogger(AbstractCommitLogSegmentManager.class); + static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentManager.class); /** * Segment that is ready to be used. The management thread fills this and blocks until consumed. * * A single management thread produces this, and consumers are already synchronizing to make sure other work is * performed atomically with consuming this. Volatile to make sure writes by the management thread become - * visible (ordered/lazySet would suffice). Consumers (advanceAllocatingFrom and discardAvailableSegment) must + * visible (ordered/lazySet would suffice). Consumers (switchToNewSegment and discardAvailableSegment) must * synchronize on 'this'. */ private volatile CommitLogSegment availableSegment = null; private final WaitQueue segmentPrepared = new WaitQueue(); - /** Active segments, containing unflushed data. The tail of this queue is the one we allocate writes to */ - private final ConcurrentLinkedQueue activeSegments = new ConcurrentLinkedQueue<>(); - - /** - * The segment we are currently allocating commit log records to. - * - * Written by advanceAllocatingFrom which synchronizes on 'this'. Volatile to ensure reads get current value. + /** Segments that are still in memtables and not yet flushed to sstables. 
+ * The tail of this queue is the one we allocate writes to. */ - private volatile CommitLogSegment allocatingFrom = null; + private final ConcurrentLinkedQueue unflushedSegments = new ConcurrentLinkedQueue<>(); + + /** The segment we are currently allocating commit log records to. */ + private volatile CommitLogSegment activeSegment = null; final String storageDirectory; @@ -83,7 +79,7 @@ public abstract class AbstractCommitLogSegmentManager private final AtomicLong size = new AtomicLong(); private Thread managerThread; - protected final CommitLog commitLog; + final CommitLog commitLog; private volatile boolean shutdown; private final BooleanSupplier managerThreadWaitCondition = () -> (availableSegment == null && !atSegmentBufferLimit()) || shutdown; private final WaitQueue managerThreadWaitQueue = new WaitQueue(); @@ -91,14 +87,20 @@ public abstract class AbstractCommitLogSegmentManager private static final SimpleCachedBufferPool bufferPool = new SimpleCachedBufferPool(DatabaseDescriptor.getCommitLogMaxCompressionBuffersInPool(), DatabaseDescriptor.getCommitLogSegmentSize()); - AbstractCommitLogSegmentManager(final CommitLog commitLog, String storageDirectory) + final CommitLogSegmentAllocator segmentAllocator; + + CommitLogSegmentManager(final CommitLog commitLog, String storageDirectory) { this.commitLog = commitLog; this.storageDirectory = storageDirectory; + this.segmentAllocator = DatabaseDescriptor.isCDCEnabled() ? 
+ new CommitLogSegmentAllocatorCDC(this) : + new CommitLogSegmentAllocatorStandard(this); } void start() { + segmentAllocator.start(); // The run loop for the manager thread Runnable runnable = new WrappedRunnable() { @@ -110,7 +112,7 @@ public void runMayThrow() throws Exception { assert availableSegment == null; logger.trace("No segments in reserve; creating a fresh one"); - availableSegment = createSegment(); + availableSegment = segmentAllocator.createSegment(); if (shutdown) { // If shutdown() started and finished during segment creation, we are now left with a @@ -127,8 +129,8 @@ public void runMayThrow() throws Exception continue; // Writing threads are not waiting for new segments, we can spend time on other tasks. - // flush old Cfs if we're full - maybeFlushToReclaim(); + // Flush old Cfs if we're full. + flushIfOverLimit(); } catch (Throwable t) { @@ -153,7 +155,7 @@ public void runMayThrow() throws Exception managerThread.start(); // for simplicity, ensure the first segment is allocated before continuing - advanceAllocatingFrom(null); + switchToNewSegment(null); } private boolean atSegmentBufferLimit() @@ -161,75 +163,58 @@ private boolean atSegmentBufferLimit() return CommitLogSegment.usesBufferPool(commitLog) && bufferPool.atLimit(); } - private void maybeFlushToReclaim() + /** + * In the event we've overallocated (i.e. size on disk > limit in config), we want to trigger a flush of the number + * of required memtables to sstables in order to be able to reclaim some CL space on disk. 
+ */ + private void flushIfOverLimit() { - long unused = unusedCapacity(); - if (unused < 0) + if (overConfigDiskCapacity(0)) { long flushingSize = 0; List segmentsToRecycle = new ArrayList<>(); - for (CommitLogSegment segment : activeSegments) + for (CommitLogSegment segment : unflushedSegments) { - if (segment == allocatingFrom) + if (segment == activeSegment) break; flushingSize += segment.onDiskSize(); segmentsToRecycle.add(segment); - if (flushingSize + unused >= 0) + if (!overConfigDiskCapacity((flushingSize))) break; } - flushDataFrom(segmentsToRecycle, false); + flushTablesForSegments(segmentsToRecycle, false); } } /** - * Allocate a segment within this CLSM. Should either succeed or throw. - */ - public abstract Allocation allocate(Mutation mutation, int size); - - /** - * Hook to allow segment managers to track state surrounding creation of new segments. Onl perform as task submit - * to segment manager so it's performed on segment management thread. - */ - abstract CommitLogSegment createSegment(); - - /** - * Indicates that a segment file has been flushed and is no longer needed. Only perform as task submit to segment - * manager so it's performend on segment management thread, or perform while segment management thread is shutdown - * during testing resets. - * - * @param segment segment to be discarded - * @param delete whether or not the segment is safe to be deleted. - */ - abstract void discard(CommitLogSegment segment, boolean delete); - - /** - * Advances the allocatingFrom pointer to the next prepared segment, but only if it is currently the segment provided. + * Advances the activeSegment pointer to the next prepared segment, but only if it is currently the segment provided. * * WARNING: Assumes segment management thread always succeeds in allocating a new segment or kills the JVM. 
*/ @DontInline - void advanceAllocatingFrom(CommitLogSegment old) + void switchToNewSegment(CommitLogSegment old) { while (true) { synchronized (this) { - // do this in a critical section so we can maintain the order of segment construction when moving to allocatingFrom/activeSegments - if (allocatingFrom != old) + // do this in a critical section so we can maintain the order of segment construction when moving to activeSegment/unflushedSegments + if (activeSegment != old) return; // If a segment is ready, take it now, otherwise wait for the management thread to construct it. if (availableSegment != null) { - // Success! Change allocatingFrom and activeSegments (which must be kept in order) before leaving + // Success! Change activeSegment and unflushedSegments (which must be kept in order) before leaving // the critical section. - activeSegments.add(allocatingFrom = availableSegment); + activeSegment = availableSegment; + unflushedSegments.add(activeSegment); availableSegment = null; break; } } - awaitAvailableSegment(old); + awaitSegmentAllocation(old); } // Signal the management thread to prepare a new segment. @@ -249,17 +234,22 @@ void advanceAllocatingFrom(CommitLogSegment old) commitLog.requestExtraSync(); } - void awaitAvailableSegment(CommitLogSegment currentAllocatingFrom) + /** + * Spins while waiting on next available segment's allocation, putting caller to sleep until the new segment is created. + * @param oldActiveSegment + */ + private void awaitSegmentAllocation(CommitLogSegment oldActiveSegment) { do { WaitQueue.Signal prepared = segmentPrepared.register(commitLog.metrics.waitingOnSegmentAllocation.time()); - if (availableSegment == null && allocatingFrom == currentAllocatingFrom) + // No new segment created, and the active segment is the one we already know about. Time to sleep... 
+ if (availableSegment == null && activeSegment == oldActiveSegment) prepared.awaitUninterruptibly(); else prepared.cancel(); } - while (availableSegment == null && allocatingFrom == currentAllocatingFrom); + while (availableSegment == null && activeSegment == oldActiveSegment); } /** @@ -269,9 +259,9 @@ void awaitAvailableSegment(CommitLogSegment currentAllocatingFrom) */ void forceRecycleAll(Iterable droppedTables) { - List segmentsToRecycle = new ArrayList<>(activeSegments); + List segmentsToRecycle = new ArrayList<>(unflushedSegments); CommitLogSegment last = segmentsToRecycle.get(segmentsToRecycle.size() - 1); - advanceAllocatingFrom(last); + switchToNewSegment(last); // wait for the commit log modifications last.waitForModifications(); @@ -281,26 +271,26 @@ void forceRecycleAll(Iterable droppedTables) Keyspace.writeOrder.awaitNewBarrier(); // flush and wait for all CFs that are dirty in segments up-to and including 'last' - Future future = flushDataFrom(segmentsToRecycle, true); + Future future = flushTablesForSegments(segmentsToRecycle, true); try { future.get(); - for (CommitLogSegment segment : activeSegments) + for (CommitLogSegment segment : unflushedSegments) for (TableId tableId : droppedTables) segment.markClean(tableId, CommitLogPosition.NONE, segment.getCurrentCommitLogPosition()); // now recycle segments that are unused, as we may not have triggered a discardCompletedSegments() // if the previous active segment was the only one to recycle (since an active segment isn't // necessarily dirty, and we only call dCS after a flush). 
- for (CommitLogSegment segment : activeSegments) + for (CommitLogSegment segment : unflushedSegments) { if (segment.isUnused()) archiveAndDiscard(segment); } - CommitLogSegment first; - if ((first = activeSegments.peek()) != null && first.id <= last.id) + CommitLogSegment first = unflushedSegments.peek(); + if (first != null && first.id <= last.id) logger.error("Failed to force-recycle all segments; at least one segment is still in use with dirty CFs."); } catch (Throwable t) @@ -318,11 +308,11 @@ void forceRecycleAll(Iterable droppedTables) void archiveAndDiscard(final CommitLogSegment segment) { boolean archiveSuccess = commitLog.archiver.maybeWaitForArchiving(segment.getName()); - if (!activeSegments.remove(segment)) + if (!unflushedSegments.remove(segment)) return; // already discarded // if archiving (command) was not successful then leave the file alone. don't delete or recycle. logger.debug("Segment {} is no longer active and will be deleted {}", segment, archiveSuccess ? "now" : "by the archive script"); - discard(segment, archiveSuccess); + segmentAllocator.discard(segment, archiveSuccess); } /** @@ -332,9 +322,7 @@ void archiveAndDiscard(final CommitLogSegment segment) */ void handleReplayedSegment(final File file) { - // (don't decrease managed size, since this was never a "live" segment) - logger.trace("(Unopened) segment {} is no longer needed and will be deleted now", file); - FileUtils.deleteWithConfirm(file); + segmentAllocator.handleReplayedSegment(file); } /** @@ -354,12 +342,16 @@ public long onDiskSize() return size.get(); } - private long unusedCapacity() + /** + * We offset by the amount we've planned to flush with to allow for selective calculation up front of how much to flush + */ + private boolean overConfigDiskCapacity(long toBeFlushed) { long total = DatabaseDescriptor.getTotalCommitlogSpaceInMB() * 1024 * 1024; - long currentSize = size.get(); + long currentSize = size.get() + toBeFlushed; logger.trace("Total active commitlog segment 
space used is {} out of {}", currentSize, total); - return total - currentSize; + // TODO: Consider whether to do >=. Original logic strictly equated with > from CASSANDRA-9095 + return currentSize > total; } /** @@ -367,7 +359,7 @@ private long unusedCapacity() * * @return a Future that will finish when all the flushes are complete. */ - private Future flushDataFrom(List segments, boolean force) + private Future flushTablesForSegments(List segments, boolean force) { if (segments.isEmpty()) return Futures.immediateFuture(null); @@ -393,7 +385,7 @@ else if (!flushes.containsKey(dirtyTableId)) final ColumnFamilyStore cfs = Keyspace.open(metadata.keyspace).getColumnFamilyStore(dirtyTableId); // can safely call forceFlush here as we will only ever block (briefly) for other attempts to flush, // no deadlock possibility since switchLock removal - flushes.put(dirtyTableId, force ? cfs.forceFlush() : cfs.forceFlush(maxCommitLogPosition)); + flushes.put(dirtyTableId, force ? cfs.forceFlushToSSTable() : cfs.forceFlushToSSTable(maxCommitLogPosition)); } } } @@ -405,7 +397,7 @@ else if (!flushes.containsKey(dirtyTableId)) * Stops CL, for testing purposes. DO NOT USE THIS OUTSIDE OF TESTS. * Only call this after the AbstractCommitLogService is shut down. 
*/ - public void stopUnsafe(boolean deleteSegments) + void stopUnsafe(boolean deleteSegments) { logger.debug("CLSM closing and clearing existing commit log segments..."); @@ -419,9 +411,9 @@ public void stopUnsafe(boolean deleteSegments) throw new RuntimeException(e); } - for (CommitLogSegment segment : activeSegments) + for (CommitLogSegment segment : unflushedSegments) closeAndDeleteSegmentUnsafe(segment, deleteSegments); - activeSegments.clear(); + unflushedSegments.clear(); size.set(0L); @@ -435,7 +427,7 @@ void awaitManagementTasksCompletion() { if (availableSegment == null && !atSegmentBufferLimit()) { - awaitAvailableSegment(allocatingFrom); + awaitSegmentAllocation(activeSegment); } } @@ -446,7 +438,7 @@ private void closeAndDeleteSegmentUnsafe(CommitLogSegment segment, boolean delet { try { - discard(segment, delete); + segmentAllocator.discard(segment, delete); } catch (AssertionError ignored) { @@ -455,13 +447,15 @@ private void closeAndDeleteSegmentUnsafe(CommitLogSegment segment, boolean delet } /** - * Initiates the shutdown process for the management thread. + * Initiates the shutdown process for the management thread and segment allocator. */ public void shutdown() { assert !shutdown; shutdown = true; + segmentAllocator.shutdown(); + // Release the management thread and delete prepared segment. // Do not block as another thread may claim the segment (this can happen during unit test initialization). discardAvailableSegment(); @@ -488,19 +482,21 @@ public void awaitTermination() throws InterruptedException managerThread.join(); managerThread = null; - for (CommitLogSegment segment : activeSegments) + for (CommitLogSegment segment : unflushedSegments) segment.close(); bufferPool.shutdown(); } /** - * @return a read-only collection of the active commit log segments + * @return a read-only collection of all active and unflushed segments in the system. 
In this context, "Flushed" is + * referring to "memtable / CF flushed to sstables", not whether or not the CommitLogSegment itself is flushed via + * fsync. */ @VisibleForTesting - public Collection getActiveSegments() + public Collection getSegmentsForUnflushedTables() { - return Collections.unmodifiableCollection(activeSegments); + return Collections.unmodifiableCollection(unflushedSegments); } /** @@ -508,7 +504,7 @@ public Collection getActiveSegments() */ CommitLogPosition getCurrentPosition() { - return allocatingFrom.getCurrentCommitLogPosition(); + return activeSegment.getCurrentCommitLogPosition(); } /** @@ -518,8 +514,8 @@ CommitLogPosition getCurrentPosition() */ public void sync(boolean flush) throws IOException { - CommitLogSegment current = allocatingFrom; - for (CommitLogSegment segment : getActiveSegments()) + CommitLogSegment current = activeSegment; + for (CommitLogSegment segment : getSegmentsForUnflushedTables()) { // Do not sync segments that became active after sync started. if (segment.id > current.id) @@ -550,10 +546,17 @@ void notifyBufferFreed() wakeManager(); } - /** Read-only access to current segment for subclasses. */ - CommitLogSegment allocatingFrom() + /** + * Pass-through call to allocator. Allocates a mutation in the active CommitLogSegment. + */ + CommitLogSegment.Allocation allocate(Mutation mutation, int size) { + return segmentAllocator.allocate(mutation, size); + } + + /** Read-only access to current segment for package usage. 
*/ + CommitLogSegment getActiveSegment() { - return allocatingFrom; + return activeSegment; } } diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerStandard.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerStandard.java deleted file mode 100644 index b9bd744da1a5..000000000000 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerStandard.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.db.commitlog; - -import java.io.File; - -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.io.util.FileUtils; - -public class CommitLogSegmentManagerStandard extends AbstractCommitLogSegmentManager -{ - public CommitLogSegmentManagerStandard(final CommitLog commitLog, String storageDirectory) - { - super(commitLog, storageDirectory); - } - - public void discard(CommitLogSegment segment, boolean delete) - { - segment.close(); - if (delete) - FileUtils.deleteWithConfirm(segment.logFile); - addSize(-segment.onDiskSize()); - } - - /** - * Reserve space in the current segment for the provided mutation or, if there isn't space available, - * create a new segment. 
allocate() is blocking until allocation succeeds as it waits on a signal in advanceAllocatingFrom - * - * @param mutation mutation to allocate space for - * @param size total size of mutation (overhead + serialized size) - * @return the provided Allocation object - */ - public CommitLogSegment.Allocation allocate(Mutation mutation, int size) - { - CommitLogSegment segment = allocatingFrom(); - - CommitLogSegment.Allocation alloc; - while ( null == (alloc = segment.allocate(mutation, size)) ) - { - // failed to allocate, so move to a new segment with enough room - advanceAllocatingFrom(segment); - segment = allocatingFrom(); - } - - return alloc; - } - - public CommitLogSegment createSegment() - { - return CommitLogSegment.createSegment(commitLog, this); - } -} diff --git a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentReader.java b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentReader.java index e23a915ba355..fd8ff3078ddb 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentReader.java +++ b/src/java/org/apache/cassandra/db/commitlog/CommitLogSegmentReader.java @@ -21,59 +21,64 @@ import java.nio.ByteBuffer; import java.util.Iterator; import java.util.zip.CRC32; +import javax.annotation.concurrent.NotThreadSafe; import javax.crypto.Cipher; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractIterator; +import org.apache.cassandra.db.commitlog.CommitLogReadHandler.CommitLogReadErrorReason; +import org.apache.cassandra.db.commitlog.CommitLogReadHandler.CommitLogReadException; import org.apache.cassandra.db.commitlog.EncryptedFileSegmentInputStream.ChunkProvider; -import org.apache.cassandra.db.commitlog.CommitLogReadHandler.*; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.compress.ICompressor; import org.apache.cassandra.io.util.FileDataInput; import org.apache.cassandra.io.util.FileSegmentInputStream; import 
org.apache.cassandra.io.util.RandomAccessReader; import org.apache.cassandra.schema.CompressionParams; -import org.apache.cassandra.security.EncryptionUtils; import org.apache.cassandra.security.EncryptionContext; +import org.apache.cassandra.security.EncryptionUtils; import org.apache.cassandra.utils.ByteBufferUtil; import static org.apache.cassandra.db.commitlog.CommitLogSegment.SYNC_MARKER_SIZE; import static org.apache.cassandra.utils.FBUtilities.updateChecksumInt; /** - * Read each sync section of a commit log, iteratively. + * Read each sync section of a commit log, iteratively. Can be run in either one-shot or resumable mode. In resumable, + * we snapshot the start position of any successful SyncSegment deserialization with the expectation that some reads will + * land in partially written segments and need to be rolled back to the start of that segment and repeated on further + * mutation serialization (specifically in encrypted or compressed contexts). */ +@NotThreadSafe public class CommitLogSegmentReader implements Iterable { - private final CommitLogReadHandler handler; - private final CommitLogDescriptor descriptor; - private final RandomAccessReader reader; + private final ResumableCommitLogReader parent; private final Segmenter segmenter; - private final boolean tolerateTruncation; - /** - * ending position of the current sync section. - */ + /** A special SyncSegment we use to indicate / keep our iterators open on a read we intend to resume */ + static final SyncSegment RESUMABLE_SENTINEL = new SyncSegment(null, -1, -1, -1, false); + + /** Ending position of the current sync section. */ protected int end; - protected CommitLogSegmentReader(CommitLogReadHandler handler, - CommitLogDescriptor descriptor, - RandomAccessReader reader, - boolean tolerateTruncation) + /** + * Rather than relying on a formal Builder, this constructs the appropriate type of segment reader (memmap, encrypted, + * compressed) based on the type stored in the descriptor. 
+ * + * Note: If ever using this object directly in a test, ensure you set the {@link ResumableCommitLogReader#offsetLimit} + * before attempting to use this reader or iteration will never advance. + */ + CommitLogSegmentReader(ResumableCommitLogReader parent) { - this.handler = handler; - this.descriptor = descriptor; - this.reader = reader; - this.tolerateTruncation = tolerateTruncation; - - end = (int) reader.getFilePointer(); - if (descriptor.getEncryptionContext().isEnabled()) - segmenter = new EncryptedSegmenter(descriptor, reader); - else if (descriptor.compression != null) - segmenter = new CompressedSegmenter(descriptor, reader); + this.parent = parent; + + end = (int) parent.rawReader.getFilePointer(); + if (parent.descriptor.getEncryptionContext().isEnabled()) + segmenter = new EncryptedSegmenter(parent.descriptor, parent); + else if (parent.descriptor.compression != null) + segmenter = new CompressedSegmenter(parent.descriptor, parent); else - segmenter = new NoOpSegmenter(reader); + segmenter = new NoOpSegmenter(parent.rawReader); } public Iterator iterator() @@ -81,36 +86,75 @@ public Iterator iterator() return new SegmentIterator(); } + /** Will return endOfData() or our resumable sentinel depending on what mode the iterator is being used in */ protected class SegmentIterator extends AbstractIterator { protected SyncSegment computeNext() { + // A couple sanity checks that we're in a good state + if (parent.offsetLimit == Integer.MIN_VALUE) + throw new RuntimeException("Attempted to use a CommitLogSegmentReader with an uninitialized ResumableCommitLogReader parent."); + + // Since this could be mis-used by client app parsing code, keep it RTE instead of assertion. 
+ if (parent.isClosed) + throw new RuntimeException("Attempted to use a closed ResumableCommitLogReader."); + while (true) { try { final int currentStart = end; - end = readSyncMarker(descriptor, currentStart, reader); - if (end == -1) + + // Segmenters need to know our original state to appropriately roll back on snapshot restore + segmenter.stageSnapshot(); + end = readSyncMarker(parent.descriptor, currentStart, parent.rawReader); + + if (parent.isPartial()) { - return endOfData(); + // Revert our SegmentIterator's state to beginning of last completed SyncSegment read on a partial read. + if (end == -1 || end > parent.offsetLimit) + { + segmenter.revertToSnapshot(); + end = (int)parent.rawReader.getFilePointer(); + return RESUMABLE_SENTINEL; + } + // Flag our RR's data as exhausted if we've hit the end of our reader but think this is partial. + else if (end >= parent.rawReader.length()) + { + parent.readToExhaustion = true; + } } - if (end > reader.length()) + // Iterate on a non-resumable read. + else { - // the CRC was good (meaning it was good when it was written and still looks legit), but the file is truncated now. - // try to grab and use as much of the file as possible, which might be nothing if the end of the file truly is corrupt - end = (int) reader.length(); + if (end == -1) + { + // We only transition to endOfData if we're doing a non-resumable (i.e. read to end) read, + // since it leaves this iterator in a non-reusable state. + return endOfData(); + } + else if (end > parent.rawReader.length()) + { + // the CRC was good (meaning it was good when it was written and still looks legit), but the file is truncated now. + // try to grab and use as much of the file as possible, which might be nothing if the end of the file truly is corrupt + end = (int) parent.rawReader.length(); + } } + + // Retain the starting point of this SyncSegment in case we need to roll back a future read to this point. + segmenter.takeSnapshot(); + + // Passed the gauntlet. 
The next segment is cleanly ready for read. return segmenter.nextSegment(currentStart + SYNC_MARKER_SIZE, end); } catch(CommitLogSegmentReader.SegmentReadException e) { try { - handler.handleUnrecoverableError(new CommitLogReadException( - e.getMessage(), - CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, - !e.invalidCrc && tolerateTruncation)); + parent.readHandler.handleUnrecoverableError(new CommitLogReadException( + e.getMessage(), + CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, + !e.invalidCrc && parent.tolerateTruncation)); } catch (IOException ioe) { @@ -121,12 +165,12 @@ protected SyncSegment computeNext() { try { - boolean tolerateErrorsInSection = tolerateTruncation & segmenter.tolerateSegmentErrors(end, reader.length()); + boolean tolerateErrorsInSection = parent.tolerateTruncation & segmenter.tolerateSegmentErrors(end, parent.rawReader.length()); // if no exception is thrown, the while loop will continue - handler.handleUnrecoverableError(new CommitLogReadException( - e.getMessage(), - CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, - tolerateErrorsInSection)); + parent.readHandler.handleUnrecoverableError(new CommitLogReadException( + e.getMessage(), + CommitLogReadErrorReason.UNRECOVERABLE_DESCRIPTOR_ERROR, + tolerateErrorsInSection)); } catch (IOException ioe) { @@ -137,13 +181,13 @@ protected SyncSegment computeNext() } } + /** + * @return length of this sync segment, -1 if at or beyond the end of file. + */ private int readSyncMarker(CommitLogDescriptor descriptor, int offset, RandomAccessReader reader) throws IOException { if (offset > reader.length() - SYNC_MARKER_SIZE) - { - // There was no room in the segment to write a final header. No data could be present here. 
return -1; - } reader.seek(offset); CRC32 crc = new CRC32(); updateChecksumInt(crc, (int) (descriptor.id & 0xFFFFFFFFL)); @@ -180,23 +224,24 @@ public SegmentReadException(String msg, boolean invalidCrc) } } + /** The logical unit of data we sync across and read across in CommitLogs. */ public static class SyncSegment { /** the 'buffer' to replay commit log data from */ public final FileDataInput input; /** offset in file where this section begins. */ - public final int fileStartPosition; + final int fileStartPosition; /** offset in file where this section ends. */ - public final int fileEndPosition; + final int fileEndPosition; /** the logical ending position of the buffer */ - public final int endPosition; + final int endPosition; - public final boolean toleratesErrorsInSection; + final boolean toleratesErrorsInSection; - public SyncSegment(FileDataInput input, int fileStartPosition, int fileEndPosition, int endPosition, boolean toleratesErrorsInSection) + SyncSegment(FileDataInput input, int fileStartPosition, int fileEndPosition, int endPosition, boolean toleratesErrorsInSection) { this.input = input; this.fileStartPosition = fileStartPosition; @@ -208,6 +253,8 @@ public SyncSegment(FileDataInput input, int fileStartPosition, int fileEndPositi /** * Derives the next section of the commit log to be replayed. Section boundaries are derived from the commit log sync markers. + * Allows snapshot and resume from snapshot functionality to revert to a "last known good segment" in the event of + * a partial read on an a file being actively written. 
*/ interface Segmenter { @@ -228,13 +275,29 @@ default boolean tolerateSegmentErrors(int segmentEndPosition, long fileLength) { return segmentEndPosition >= fileLength || segmentEndPosition < 0; } + + /** Holds snapshot data in temporary variables to be finalized when we determine a SyncSegment is fully written */ + void stageSnapshot(); + + /** Finalizes snapshot staged in stageSnapshot */ + void takeSnapshot(); + + /** Reverts the segmenter to the previously held position. Allows for resumable reads to rollback when they occur + * in the middle of a SyncSegment. This can be called repeatedly if we have multiple attempts to partially read + * on an incomplete SyncSegment. */ + void revertToSnapshot(); + + /** Visible for debugging only */ + long getSnapshot(); } static class NoOpSegmenter implements Segmenter { private final RandomAccessReader reader; + private long snapshotPosition = Long.MIN_VALUE; + private long stagedSnapshot = Long.MIN_VALUE; - public NoOpSegmenter(RandomAccessReader reader) + NoOpSegmenter(RandomAccessReader reader) { this.reader = reader; } @@ -249,54 +312,114 @@ public boolean tolerateSegmentErrors(int end, long length) { return true; } + + public void stageSnapshot() + { + stagedSnapshot = reader.getFilePointer(); + // Deal with edge case of initial read attempt being before SyncSegment completion + if (snapshotPosition == Long.MIN_VALUE) + takeSnapshot(); + } + + public void takeSnapshot() + { + snapshotPosition = stagedSnapshot; + } + + public void revertToSnapshot() + { + reader.seek(snapshotPosition); + } + + public long getSnapshot() + { + return snapshotPosition; + } } static class CompressedSegmenter implements Segmenter { private final ICompressor compressor; - private final RandomAccessReader reader; + /** We store a reference to a ResumableReader in the event it needs to re-init and swap out the underlying reader */ + private final ResumableCommitLogReader parent; private byte[] compressedBuffer; private byte[] 
uncompressedBuffer; private long nextLogicalStart; - public CompressedSegmenter(CommitLogDescriptor desc, RandomAccessReader reader) + private long stagedLogicalStart = Long.MIN_VALUE; + private long stagedReaderLocation = Long.MIN_VALUE; + private long snapshotLogicalStart = Long.MIN_VALUE; + private long snapshotReaderLocation = Long.MIN_VALUE; + + CompressedSegmenter(CommitLogDescriptor desc, ResumableCommitLogReader parent) { - this(CompressionParams.createCompressor(desc.compression), reader); + this(CompressionParams.createCompressor(desc.compression), parent); } - public CompressedSegmenter(ICompressor compressor, RandomAccessReader reader) + CompressedSegmenter(ICompressor compressor, ResumableCommitLogReader parent) { this.compressor = compressor; - this.reader = reader; + this.parent = parent; compressedBuffer = new byte[0]; uncompressedBuffer = new byte[0]; - nextLogicalStart = reader.getFilePointer(); + nextLogicalStart = parent.rawReader.getFilePointer(); } @SuppressWarnings("resource") public SyncSegment nextSegment(final int startPosition, final int nextSectionStartPosition) throws IOException { - reader.seek(startPosition); - int uncompressedLength = reader.readInt(); + parent.rawReader.seek(startPosition); + int uncompressedLength = parent.rawReader.readInt(); - int compressedLength = nextSectionStartPosition - (int)reader.getPosition(); + int compressedLength = nextSectionStartPosition - (int)parent.rawReader.getPosition(); if (compressedLength > compressedBuffer.length) compressedBuffer = new byte[(int) (1.2 * compressedLength)]; - reader.readFully(compressedBuffer, 0, compressedLength); + parent.rawReader.readFully(compressedBuffer, 0, compressedLength); if (uncompressedLength > uncompressedBuffer.length) uncompressedBuffer = new byte[(int) (1.2 * uncompressedLength)]; int count = compressor.uncompress(compressedBuffer, 0, compressedLength, uncompressedBuffer, 0); nextLogicalStart += SYNC_MARKER_SIZE; - FileDataInput input = new 
FileSegmentInputStream(ByteBuffer.wrap(uncompressedBuffer, 0, count), reader.getPath(), nextLogicalStart); + FileDataInput input = new FileSegmentInputStream(ByteBuffer.wrap(uncompressedBuffer, 0, count), parent.rawReader.getPath(), nextLogicalStart); nextLogicalStart += uncompressedLength; - return new SyncSegment(input, startPosition, nextSectionStartPosition, (int)nextLogicalStart, tolerateSegmentErrors(nextSectionStartPosition, reader.length())); + return new SyncSegment(input, startPosition, nextSectionStartPosition, (int)nextLogicalStart, tolerateSegmentErrors(nextSectionStartPosition, parent.rawReader.length())); + } + + public void stageSnapshot() + { + stagedLogicalStart = nextLogicalStart; + stagedReaderLocation = parent.rawReader.getFilePointer(); + + // In our default 0 case on a segment w/out anything yet to read, we want to stage the first valid location + // we've seen, else a resume will kick us to a bad value + if (snapshotLogicalStart == Long.MIN_VALUE) + takeSnapshot(); + } + + /** Since {@link #nextLogicalStart} is mutated during decompression but relied upon for decompression, we need + * to both snapshot and revert that along with the reader's position. 
*/ + public void takeSnapshot() + { + snapshotLogicalStart = stagedLogicalStart; + snapshotReaderLocation = stagedReaderLocation; + } + + public void revertToSnapshot() + { + nextLogicalStart = snapshotLogicalStart; + parent.rawReader.seek(snapshotReaderLocation); + } + + public long getSnapshot() + { + return snapshotReaderLocation; } } static class EncryptedSegmenter implements Segmenter { - private final RandomAccessReader reader; + /** We store a reference to a ResumableReader in the event it needs to re-init and swap out the underlying reader */ + private final ResumableCommitLogReader parent; private final ICompressor compressor; private final Cipher cipher; @@ -315,18 +438,21 @@ static class EncryptedSegmenter implements Segmenter private long currentSegmentEndPosition; private long nextLogicalStart; - public EncryptedSegmenter(CommitLogDescriptor descriptor, RandomAccessReader reader) + private long stagedSnapshotPosition; + private long snapshotPosition; + + EncryptedSegmenter(CommitLogDescriptor descriptor, ResumableCommitLogReader parent) { - this(reader, descriptor.getEncryptionContext()); + this(parent, descriptor.getEncryptionContext()); } @VisibleForTesting - EncryptedSegmenter(final RandomAccessReader reader, EncryptionContext encryptionContext) + EncryptedSegmenter(final ResumableCommitLogReader parent, EncryptionContext encryptionContext) { - this.reader = reader; + this.parent = parent; decryptedBuffer = ByteBuffer.allocate(0); compressor = encryptionContext.getCompressor(); - nextLogicalStart = reader.getFilePointer(); + nextLogicalStart = parent.rawReader.getFilePointer(); try { @@ -334,21 +460,21 @@ public EncryptedSegmenter(CommitLogDescriptor descriptor, RandomAccessReader rea } catch (IOException ioe) { - throw new FSReadError(ioe, reader.getPath()); + throw new FSReadError(ioe, parent.rawReader.getPath()); } chunkProvider = () -> { - if (reader.getFilePointer() >= currentSegmentEndPosition) + if (parent.rawReader.getFilePointer() >= 
currentSegmentEndPosition) return ByteBufferUtil.EMPTY_BYTE_BUFFER; try { - decryptedBuffer = EncryptionUtils.decrypt(reader, decryptedBuffer, true, cipher); + decryptedBuffer = EncryptionUtils.decrypt(parent.rawReader, decryptedBuffer, true, cipher); uncompressedBuffer = EncryptionUtils.uncompress(decryptedBuffer, uncompressedBuffer, true, compressor); return uncompressedBuffer; } catch (IOException e) { - throw new FSReadError(e, reader.getPath()); + throw new FSReadError(e, parent.rawReader.getPath()); } }; } @@ -356,13 +482,35 @@ public EncryptedSegmenter(CommitLogDescriptor descriptor, RandomAccessReader rea @SuppressWarnings("resource") public SyncSegment nextSegment(int startPosition, int nextSectionStartPosition) throws IOException { - int totalPlainTextLength = reader.readInt(); + int totalPlainTextLength = parent.rawReader.readInt(); currentSegmentEndPosition = nextSectionStartPosition - 1; nextLogicalStart += SYNC_MARKER_SIZE; - FileDataInput input = new EncryptedFileSegmentInputStream(reader.getPath(), nextLogicalStart, 0, totalPlainTextLength, chunkProvider); + FileDataInput input = new EncryptedFileSegmentInputStream(parent.rawReader.getPath(), nextLogicalStart, 0, totalPlainTextLength, chunkProvider); nextLogicalStart += totalPlainTextLength; - return new SyncSegment(input, startPosition, nextSectionStartPosition, (int)nextLogicalStart, tolerateSegmentErrors(nextSectionStartPosition, reader.length())); + return new SyncSegment(input, startPosition, nextSectionStartPosition, (int)nextLogicalStart, tolerateSegmentErrors(nextSectionStartPosition, parent.rawReader.length())); + } + + public void stageSnapshot() + { + stagedSnapshotPosition = parent.rawReader.getFilePointer(); + if (snapshotPosition == Long.MIN_VALUE) + takeSnapshot(); + } + + public void takeSnapshot() + { + snapshotPosition = stagedSnapshotPosition; + } + + public void revertToSnapshot() + { + parent.rawReader.seek(snapshotPosition); + } + + public long getSnapshot() + { + return 
snapshotPosition; } } } diff --git a/src/java/org/apache/cassandra/db/commitlog/CompressedSegment.java b/src/java/org/apache/cassandra/db/commitlog/CompressedSegment.java index d5e61137d842..c74e8430b218 100644 --- a/src/java/org/apache/cassandra/db/commitlog/CompressedSegment.java +++ b/src/java/org/apache/cassandra/db/commitlog/CompressedSegment.java @@ -41,7 +41,7 @@ public class CompressedSegment extends FileDirectSegment /** * Constructs a new segment file. */ - CompressedSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + CompressedSegment(CommitLog commitLog, CommitLogSegmentManager manager) { super(commitLog, manager); this.compressor = commitLog.configuration.getCompressor(); @@ -59,7 +59,7 @@ void write(int startMarker, int nextMarker) int contentStart = startMarker + SYNC_MARKER_SIZE; int length = nextMarker - contentStart; // The length may be 0 when the segment is being closed. - assert length > 0 || length == 0 && !isStillAllocating(); + assert length > 0 || length == 0 && !hasRoom(); try { diff --git a/src/java/org/apache/cassandra/db/commitlog/EncryptedSegment.java b/src/java/org/apache/cassandra/db/commitlog/EncryptedSegment.java index 21b7c11fb052..b699da438290 100644 --- a/src/java/org/apache/cassandra/db/commitlog/EncryptedSegment.java +++ b/src/java/org/apache/cassandra/db/commitlog/EncryptedSegment.java @@ -64,7 +64,7 @@ public class EncryptedSegment extends FileDirectSegment private final EncryptionContext encryptionContext; private final Cipher cipher; - public EncryptedSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + public EncryptedSegment(CommitLog commitLog, CommitLogSegmentManager manager) { super(commitLog, manager); this.encryptionContext = commitLog.configuration.getEncryptionContext(); @@ -101,7 +101,7 @@ void write(int startMarker, int nextMarker) int contentStart = startMarker + SYNC_MARKER_SIZE; final int length = nextMarker - contentStart; // The length may be 0 when the segment is 
being closed. - assert length > 0 || length == 0 && !isStillAllocating(); + assert length > 0 || length == 0 && !hasRoom(); final ICompressor compressor = encryptionContext.getCompressor(); final int blockSize = encryptionContext.getChunkLength(); diff --git a/src/java/org/apache/cassandra/db/commitlog/FileDirectSegment.java b/src/java/org/apache/cassandra/db/commitlog/FileDirectSegment.java index d5431f875b5f..d82fd10d40a6 100644 --- a/src/java/org/apache/cassandra/db/commitlog/FileDirectSegment.java +++ b/src/java/org/apache/cassandra/db/commitlog/FileDirectSegment.java @@ -31,7 +31,7 @@ public abstract class FileDirectSegment extends CommitLogSegment { volatile long lastWrittenPos = 0; - FileDirectSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + FileDirectSegment(CommitLog commitLog, CommitLogSegmentManager manager) { super(commitLog, manager); } diff --git a/src/java/org/apache/cassandra/db/commitlog/MemoryMappedSegment.java b/src/java/org/apache/cassandra/db/commitlog/MemoryMappedSegment.java index 6ecdbd3c7764..9009e4cd2250 100644 --- a/src/java/org/apache/cassandra/db/commitlog/MemoryMappedSegment.java +++ b/src/java/org/apache/cassandra/db/commitlog/MemoryMappedSegment.java @@ -22,6 +22,9 @@ import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.io.FSWriteError; import org.apache.cassandra.io.util.FileUtils; @@ -35,12 +38,14 @@ */ public class MemoryMappedSegment extends CommitLogSegment { + private static final Logger logger = LoggerFactory.getLogger(CommitLogSegmentReader.class); + /** * Constructs a new segment file. * * @param commitLog the commit log it will be used with. 
*/ - MemoryMappedSegment(CommitLog commitLog, AbstractCommitLogSegmentManager manager) + MemoryMappedSegment(CommitLog commitLog, CommitLogSegmentManager manager) { super(commitLog, manager); // mark the initial sync marker as uninitialised diff --git a/src/java/org/apache/cassandra/db/commitlog/ResumableCommitLogReader.java b/src/java/org/apache/cassandra/db/commitlog/ResumableCommitLogReader.java new file mode 100644 index 000000000000..ac925bb5f36c --- /dev/null +++ b/src/java/org/apache/cassandra/db/commitlog/ResumableCommitLogReader.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.db.commitlog; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.text.MessageFormat; +import java.util.Iterator; +import java.util.Optional; +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.io.util.RandomAccessReader; +import org.apache.cassandra.utils.Pair; + +/** + * A state holder for a potentially resumable read. 
As we want to resume our reading with existing file pointers, buffers, + * and sentinels without re-opening a file and re-decompressing or decrypting, we store references to a {@link CommitLogSegmentReader}, + * to a {@link RandomAccessReader}, and to an iterator of {@link CommitLogSegmentReader.SyncSegment} here for re-use. + * + * This serves dual purpose as an API endpoint and logical state holder to pop out our handles across multiple reads + * while minimizing pollution to core CommitLogReader code implementation. + * + * _Mandatory_ usage of this API is as follows: + * 0-N) {@link #readPartial(int limit)} + * 1) {@link #readToCompletion()} + * NOTE: neither of these callse will {@link #close} this reader. try-with-resources is the correct usage. + * to correctly close out. + * + * As this is intended to be used both in an internal C* state as well as by external users looking to read CommitLogSegments, + * we allow construction to fail gracefully and indicate usability through {@link #isClosed()}. + */ +@NotThreadSafe +public class ResumableCommitLogReader implements AutoCloseable +{ + private static final Logger logger = LoggerFactory.getLogger(CommitLogReader.class); + + /** We hold a reference to these so we can re-use them for subsequent descriptor parsing on resumed reads */ + final File segmentFile; + final CommitLogDescriptor descriptor; + private final CommitLogReader commitLogReader; + final CommitLogReadHandler readHandler; + final boolean tolerateTruncation; + + /** Can be re-initialized if re-reading a reader w/compression enabled and we're at our known limit. */ + RandomAccessReader rawReader; + + /** We allow the users to determine whether or not the system should continue on various forms of read failure. As + * such, we allow resumable readers to be constructed even if they are unusable to the end-user. */ + boolean isClosed = false; + + /** Separate sentinel to indicate whether we have read to completion on our underlying file. 
Flagged by SegmentReader + * We use this during {@link #reBufferData()} to determine whether or not to recreate our underlying RAR in the compressed + * case. + */ + boolean readToExhaustion = false; + + /** Minimum position before which we completely skip CommitLogSegments */ + final CommitLogPosition minPosition; + + /** Sentinel used to limit reads */ + final int mutationLimit; + + /** We cache, snapshot, and revert position inside our {@link CommitLogSegmentReader.SegmentIterator#computeNext} calls + * to keep the user-facing API simple */ + @Nullable + Iterator activeIterator; + + @Nullable + private CommitLogSegmentReader segmentReader; + + /** Raw file offset at which we will stop iterating and processing mutations on a read */ + int offsetLimit = Integer.MIN_VALUE; + + public ResumableCommitLogReader(File commitLogSegment, CommitLogReadHandler readHandler) throws IOException, NullPointerException + { + this(commitLogSegment, readHandler, CommitLogPosition.NONE, CommitLogReader.ALL_MUTATIONS, true); + } + + public ResumableCommitLogReader(File commitLogSegment, + CommitLogReadHandler readHandler, + CommitLogPosition minPosition, + int mutationLimit, + boolean tolerateTruncation) throws IOException, NullPointerException + { + this.segmentFile = commitLogSegment; + this.commitLogReader = new CommitLogReader(); + this.readHandler = readHandler; + this.mutationLimit = mutationLimit; + this.minPosition = minPosition; + this.tolerateTruncation = tolerateTruncation; + + Pair, Integer> header = CommitLogReader.readCommitLogDescriptor(readHandler, + commitLogSegment, + tolerateTruncation); + // Totally fail out if we fail to parse this CommitLogSegment descriptor + if (!header.left.isPresent()) + throw new RuntimeException(MessageFormat.format("Failed to parse the CommitLogDescriptor from {0}", commitLogSegment)); + descriptor = header.left.get(); + + if (shouldSkipSegmentId(new File(descriptor.fileName()), descriptor, minPosition)) + { + close(); + } + else + { + 
try + { + this.rawReader = RandomAccessReader.open(commitLogSegment); + rawReader.seek(header.right); + + // This is where we grab and old open our handles if we succeed + segmentReader = CommitLogReader.getCommitLogSegmentReader(this); + if (segmentReader != null) + this.activeIterator = segmentReader.iterator(); + } + finally + { + if (segmentReader == null) + close(); + } + } + } + + /** + * Performs a partial CommitLogSegment read. Closes down this resumable reader on read error. + * + * @param readLimit How far to read into the file before stopping. + */ + public void readPartial(int readLimit) throws IOException + { + if (isClosed) + { + logger.warn("Attempted to use invalid ResumableCommitLogReader for file {}. Ignoring.", descriptor.fileName()); + return; + } + + if (readLimit <= offsetLimit) + { + logger.warn("Attempted to resume reading a commit log but used already read offset: {}", readLimit); + return; + } + offsetLimit = readLimit; + rebufferAndRead(); + } + + /** Reads to end of file from current cached offset. */ + public void readToCompletion() throws IOException + { + if (isClosed) + { + logger.warn("Attempted to use invalid ResumableCommitLogReader for file {}. Ignoring.", descriptor.fileName()); + return; + } + offsetLimit = CommitLogReader.READ_TO_END_OF_FILE; + rebufferAndRead(); + } + + public void close() + { + isClosed = true; + if (rawReader != null) + rawReader.close(); + segmentReader = null; + activeIterator = null; + } + + public boolean isClosed() + { + return isClosed; + } + + /** + * When we have compression enabled, RandomAccessReader's have CompressionMetadata to indicate the end of their file + * length. 
For our purposes, this means we have some difficulty in re-using previously constructed underlying buffers + * for decompression and reading, so if the underlying file length has changed because a file is actively being written + * and we've exhausted the current data we know about, we close out our RAR and construct a new one with the new + * {@link org.apache.cassandra.io.compress.CompressionMetadata}. While it would arguably be better to extend the + * hierarchy to have a rebuffering compressed segment, YAGNI for now. The added gc pressure from this + overhead + * on closing and re-opening RAR's should be restricted to non-node partial/resumed CL reading cases which we expect + * to have very different properties than critical path log replay on a running node, for example. + */ + private void reBufferData() throws FileNotFoundException + { + if (readToExhaustion) + { + long toSeek = rawReader.getPosition(); + this.rawReader.close(); + if (!segmentFile.exists()) + throw new FileNotFoundException(String.format("Attempting to reBufferData but underlying file cannot be found: {}", + segmentFile.getAbsolutePath())); + this.rawReader = RandomAccessReader.open(segmentFile); + this.rawReader.seek(toSeek); + } + else + { + rawReader.reBuffer(); + } + } + + /** Performs the read operation and closes down this reader on exception. */ + private void rebufferAndRead() throws RuntimeException, IOException + { + reBufferData(); + + try + { + commitLogReader.internalReadCommitLogSegment(this); + } + catch (RuntimeException | IOException e) + { + close(); + throw e; + } + } + + /** Any segment with id >= minPosition.segmentId is a candidate for read. 
*/ + private boolean shouldSkipSegmentId(File file, CommitLogDescriptor desc, CommitLogPosition minPosition) + { + logger.debug("Reading {} (CL version {}, messaging version {}, compression {})", + file.getPath(), + desc.version, + desc.getMessagingVersion(), + desc.compression); + + if (minPosition.segmentId > desc.id) + { + logger.trace("Skipping read of fully-flushed {}", file); + return true; + } + return false; + } + + /** Flag to indicate how the {@link CommitLogSegmentReader.SegmentIterator} should behave on failure to compute next + * segments. + */ + boolean isPartial() + { + return offsetLimit != CommitLogReader.READ_TO_END_OF_FILE; + } + + @Override + public String toString() + { + return new StringBuilder() + .append("File: ").append(descriptor.fileName()).append(", ") + .append("minPos: ").append(minPosition).append(", ") + .append("offsetLimit: ").append(offsetLimit).append(", ") + .append("readerPos: ").append(rawReader.getPosition()).append(", ") + .append("activeIter: ").append(activeIterator) + .toString(); + } +} diff --git a/src/java/org/apache/cassandra/db/compaction/AbstractCompactionStrategy.java b/src/java/org/apache/cassandra/db/compaction/AbstractCompactionStrategy.java index 74c154d51905..ad494b15da06 100644 --- a/src/java/org/apache/cassandra/db/compaction/AbstractCompactionStrategy.java +++ b/src/java/org/apache/cassandra/db/compaction/AbstractCompactionStrategy.java @@ -21,8 +21,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; -import com.google.common.base.Predicate; -import com.google.common.collect.Iterables; import org.apache.cassandra.db.Directories; import org.apache.cassandra.db.SerializationHeader; @@ -412,8 +410,8 @@ else if (CompactionController.getFullyExpiredSSTables(cfs, Collections.singleton ranges.add(new Range<>(overlap.first.getToken(), overlap.last.getToken())); long remainingKeys = keys - sstable.estimatedKeysForRanges(ranges); // next, calculate what 
percentage of columns we have within those keys - long columns = sstable.getEstimatedColumnCount().mean() * remainingKeys; - double remainingColumnsRatio = ((double) columns) / (sstable.getEstimatedColumnCount().count() * sstable.getEstimatedColumnCount().mean()); + long columns = sstable.getEstimatedCellPerPartitionCount().mean() * remainingKeys; + double remainingColumnsRatio = ((double) columns) / (sstable.getEstimatedCellPerPartitionCount().count() * sstable.getEstimatedCellPerPartitionCount().mean()); // return if we still expect to have droppable tombstones in rest of columns return remainingColumnsRatio * droppableRatio > tombstoneThreshold; diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionController.java b/src/java/org/apache/cassandra/db/compaction/CompactionController.java index 59bba0a5cc8b..e1b0f3258359 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionController.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionController.java @@ -19,26 +19,21 @@ import java.util.*; import java.util.function.LongPredicate; -import java.util.function.Predicate; - -import org.apache.cassandra.config.Config; -import org.apache.cassandra.db.Memtable; -import org.apache.cassandra.db.rows.UnfilteredRowIterator; import com.google.common.base.Predicates; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.RateLimiter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.cassandra.config.Config; +import org.apache.cassandra.db.*; import org.apache.cassandra.db.partitions.Partition; +import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.FileDataInput; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.schema.CompactionParams.TombstoneOption; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.db.*; import 
org.apache.cassandra.utils.AlwaysPresentFilter; import org.apache.cassandra.utils.OverlapIterator; import org.apache.cassandra.utils.concurrent.Refs; @@ -48,13 +43,12 @@ /** * Manage compaction options. */ -public class CompactionController implements AutoCloseable +public class CompactionController extends AbstractCompactionController { private static final Logger logger = LoggerFactory.getLogger(CompactionController.class); private static final String NEVER_PURGE_TOMBSTONES_PROPERTY = Config.PROPERTY_PREFIX + "never_purge_tombstones"; static final boolean NEVER_PURGE_TOMBSTONES = Boolean.getBoolean(NEVER_PURGE_TOMBSTONES_PROPERTY); - public final ColumnFamilyStore cfs; private final boolean compactingRepaired; // note that overlapIterator and overlappingSSTables will be null if NEVER_PURGE_TOMBSTONES is set - this is a // good thing so that noone starts using them and thinks that if overlappingSSTables is empty, there @@ -64,11 +58,8 @@ public class CompactionController implements AutoCloseable private final Iterable compacting; private final RateLimiter limiter; private final long minTimestamp; - final TombstoneOption tombstoneOption; final Map openDataFiles = new HashMap<>(); - public final int gcBefore; - protected CompactionController(ColumnFamilyStore cfs, int maxValue) { this(cfs, null, maxValue); @@ -82,13 +73,10 @@ public CompactionController(ColumnFamilyStore cfs, Set compacting public CompactionController(ColumnFamilyStore cfs, Set compacting, int gcBefore, RateLimiter limiter, TombstoneOption tombstoneOption) { - assert cfs != null; - this.cfs = cfs; - this.gcBefore = gcBefore; + super(cfs, gcBefore, tombstoneOption); this.compacting = compacting; this.limiter = limiter; compactingRepaired = compacting != null && compacting.stream().allMatch(SSTableReader::isRepaired); - this.tombstoneOption = tombstoneOption; this.minTimestamp = compacting != null && !compacting.isEmpty() // check needed for test ? 
compacting.stream().mapToLong(SSTableReader::getMinTimestamp).min().getAsLong() : 0; @@ -246,16 +234,6 @@ public static Set getFullyExpiredSSTables(ColumnFamilyStore cfSto return getFullyExpiredSSTables(cfStore, compacting, overlapping, gcBefore, false); } - public String getKeyspace() - { - return cfs.keyspace.getName(); - } - - public String getColumnFamily() - { - return cfs.name; - } - /** * @param key * @return a predicate for whether tombstones marked for deletion at the given time for the given partition are @@ -263,6 +241,7 @@ public String getColumnFamily() * containing his partition and not participating in the compaction. This means there isn't any data in those * sstables that might still need to be suppressed by a tombstone at this timestamp. */ + @Override public LongPredicate getPurgeEvaluator(DecoratedKey key) { if (NEVER_PURGE_TOMBSTONES || !compactingRepaired() || cfs.getNeverPurgeTombstones()) diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionIterator.java b/src/java/org/apache/cassandra/db/compaction/CompactionIterator.java index 1c56a87bc82d..789d1eeeb5c8 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionIterator.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionIterator.java @@ -57,7 +57,7 @@ public class CompactionIterator extends CompactionInfo.Holder implements Unfilte private static final long UNFILTERED_TO_UPDATE_PROGRESS = 100; private final OperationType type; - private final CompactionController controller; + private final AbstractCompactionController controller; private final List scanners; private final ImmutableSet sstables; private final int nowInSec; @@ -77,13 +77,13 @@ public class CompactionIterator extends CompactionInfo.Holder implements Unfilte private final UnfilteredPartitionIterator compacted; private final ActiveCompactionsTracker activeCompactions; - public CompactionIterator(OperationType type, List scanners, CompactionController controller, int nowInSec, UUID 
compactionId) + public CompactionIterator(OperationType type, List scanners, AbstractCompactionController controller, int nowInSec, UUID compactionId) { this(type, scanners, controller, nowInSec, compactionId, ActiveCompactionsTracker.NOOP); } @SuppressWarnings("resource") // We make sure to close mergedIterator in close() and CompactionIterator is itself an AutoCloseable - public CompactionIterator(OperationType type, List scanners, CompactionController controller, int nowInSec, UUID compactionId, ActiveCompactionsTracker activeCompactions) + public CompactionIterator(OperationType type, List scanners, AbstractCompactionController controller, int nowInSec, UUID compactionId, ActiveCompactionsTracker activeCompactions) { this.controller = controller; this.type = type; @@ -259,14 +259,14 @@ public String toString() private class Purger extends PurgeFunction { - private final CompactionController controller; + private final AbstractCompactionController controller; private DecoratedKey currentKey; private LongPredicate purgeEvaluator; private long compactedUnfiltered; - private Purger(CompactionController controller, int nowInSec) + private Purger(AbstractCompactionController controller, int nowInSec) { super(nowInSec, controller.gcBefore, controller.compactingRepaired() ? 
Integer.MAX_VALUE : Integer.MIN_VALUE, controller.cfs.getCompactionStrategyManager().onlyPurgeRepairedTombstones(), @@ -510,10 +510,10 @@ private DeletionTime updateOpenDeletionTime(DeletionTime openDeletionTime, Unfil */ private static class GarbageSkipper extends Transformation { - final CompactionController controller; + final AbstractCompactionController controller; final boolean cellLevelGC; - private GarbageSkipper(CompactionController controller) + private GarbageSkipper(AbstractCompactionController controller) { this.controller = controller; cellLevelGC = controller.tombstoneOption == TombstoneOption.CELL; diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java index d38770180c37..bb1d585e00b9 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java @@ -346,6 +346,7 @@ public BackgroundCompactionCandidate getBackgroundCompactionCandidate(ColumnFami @SuppressWarnings("resource") private AllSSTableOpStatus parallelAllSSTableOperation(final ColumnFamilyStore cfs, final OneSSTableOperation operation, int jobs, OperationType operationType) throws ExecutionException, InterruptedException { + logger.info("Starting {} for {}.{}", operationType, cfs.keyspace.getName(), cfs.getTableName()); List transactions = new ArrayList<>(); List> futures = new ArrayList<>(); try (LifecycleTransaction compacting = cfs.markAllCompacting(operationType)) @@ -387,6 +388,7 @@ public Object call() throws Exception } FBUtilities.waitOnFutures(futures); assert compacting.originals().isEmpty(); + logger.info("Finished {} for {}.{} successfully", operationType, cfs.keyspace.getName(), cfs.getTableName()); return AllSSTableOpStatus.SUCCESSFUL; } finally @@ -402,7 +404,7 @@ public Object call() throws Exception } Throwable fail = Throwables.close(null, transactions); if (fail != null) - 
logger.error("Failed to cleanup lifecycle transactions", fail); + logger.error("Failed to cleanup lifecycle transactions ({} for {}.{})", operationType, cfs.keyspace.getName(), cfs.getTableName(), fail); } } @@ -527,7 +529,34 @@ public AllSSTableOpStatus performCleanup(final ColumnFamilyStore cfStore, int jo public Iterable filterSSTables(LifecycleTransaction transaction) { List sortedSSTables = Lists.newArrayList(transaction.originals()); - Collections.sort(sortedSSTables, SSTableReader.sizeComparator); + Iterator sstableIter = sortedSSTables.iterator(); + int totalSSTables = 0; + int skippedSStables = 0; + while (sstableIter.hasNext()) + { + SSTableReader sstable = sstableIter.next(); + boolean needsCleanupFull = needsCleanup(sstable, fullRanges); + boolean needsCleanupTransient = needsCleanup(sstable, transientRanges); + //If there are no ranges for which the table needs cleanup either due to lack of intersection or lack + //of the table being repaired. + totalSSTables++; + if (!needsCleanupFull && (!needsCleanupTransient || !sstable.isRepaired())) + { + logger.debug("Skipping {} ([{}, {}]) for cleanup; all rows should be kept. 
Needs cleanup full ranges: {} Needs cleanup transient ranges: {} Repaired: {}", + sstable, + sstable.first.getToken(), + sstable.last.getToken(), + needsCleanupFull, + needsCleanupTransient, + sstable.isRepaired()); + sstableIter.remove(); + transaction.cancel(sstable); + skippedSStables++; + } + } + logger.info("Skipping cleanup for {}/{} sstables for {}.{} since they are fully contained in owned ranges (full ranges: {}, transient ranges: {})", + skippedSStables, totalSSTables, cfStore.keyspace.getName(), cfStore.getTableName(), fullRanges, transientRanges); + sortedSSTables.sort(SSTableReader.sizeComparator); return sortedSSTables; } @@ -817,14 +846,15 @@ public void performMaximal(final ColumnFamilyStore cfStore, boolean splitOutput) FBUtilities.waitOnFutures(submitMaximal(cfStore, getDefaultGcBefore(cfStore, FBUtilities.nowInSeconds()), splitOutput)); } + @SuppressWarnings("resource") // the tasks are executed in parallel on the executor, making sure that they get closed public List> submitMaximal(final ColumnFamilyStore cfStore, final int gcBefore, boolean splitOutput) { // here we compute the task off the compaction executor, so having that present doesn't // confuse runWithCompactionsDisabled -- i.e., we don't want to deadlock ourselves, waiting // for ourselves to finish/acknowledge cancellation before continuing. 
- final Collection tasks = cfStore.getCompactionStrategyManager().getMaximalTasks(gcBefore, splitOutput); + CompactionTasks tasks = cfStore.getCompactionStrategyManager().getMaximalTasks(gcBefore, splitOutput); - if (tasks == null) + if (tasks.isEmpty()) return Collections.emptyList(); List> futures = new ArrayList<>(); @@ -850,42 +880,42 @@ protected void runMayThrow() if (nonEmptyTasks > 1) logger.info("Major compaction will not result in a single sstable - repaired and unrepaired data is kept separate and compaction runs per data_file_directory."); - return futures; } public void forceCompactionForTokenRange(ColumnFamilyStore cfStore, Collection> ranges) { - final Collection tasks = cfStore.runWithCompactionsDisabled(() -> - { - Collection sstables = sstablesInBounds(cfStore, ranges); - if (sstables == null || sstables.isEmpty()) - { - logger.debug("No sstables found for the provided token range"); - return null; - } - return cfStore.getCompactionStrategyManager().getUserDefinedTasks(sstables, getDefaultGcBefore(cfStore, FBUtilities.nowInSeconds())); - }, (sstable) -> new Bounds<>(sstable.first.getToken(), sstable.last.getToken()).intersects(ranges), false, false, false); - - if (tasks == null) - return; - - Runnable runnable = new WrappedRunnable() - { - protected void runMayThrow() + Callable taskCreator = () -> { + Collection sstables = sstablesInBounds(cfStore, ranges); + if (sstables == null || sstables.isEmpty()) { - for (AbstractCompactionTask task : tasks) - if (task != null) - task.execute(active); + logger.debug("No sstables found for the provided token range"); + return CompactionTasks.empty(); } + return cfStore.getCompactionStrategyManager().getUserDefinedTasks(sstables, getDefaultGcBefore(cfStore, FBUtilities.nowInSeconds())); }; - if (executor.isShutdown()) + try (CompactionTasks tasks = cfStore.runWithCompactionsDisabled(taskCreator, + (sstable) -> new Bounds<>(sstable.first.getToken(), sstable.last.getToken()).intersects(ranges), + false, + 
false, + false)) { - logger.info("Compaction executor has shut down, not submitting task"); - return; + if (tasks.isEmpty()) + return; + + Runnable runnable = new WrappedRunnable() + { + protected void runMayThrow() + { + for (AbstractCompactionTask task : tasks) + if (task != null) + task.execute(active); + } + }; + + FBUtilities.waitOnFuture(executor.submitIfRunning(runnable, "force compaction for token range")); } - FBUtilities.waitOnFuture(executor.submit(runnable)); } private static Collection sstablesInBounds(ColumnFamilyStore cfs, Collection> tokenRangeCollection) @@ -990,7 +1020,7 @@ public Future submitUserDefined(final ColumnFamilyStore cfs, final Collection { Runnable runnable = new WrappedRunnable() { - protected void runMayThrow() + protected void runMayThrow() throws Exception { // look up the sstables now that we're on the compaction executor, so we don't try to re-compact // something that was already being compacted earlier. @@ -1015,11 +1045,13 @@ protected void runMayThrow() } else { - List tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstables, gcBefore); - for (AbstractCompactionTask task : tasks) + try (CompactionTasks tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstables, gcBefore)) { - if (task != null) - task.execute(active); + for (AbstractCompactionTask task : tasks) + { + if (task != null) + task.execute(active); + } } } } @@ -1171,16 +1203,7 @@ private void doCleanupOne(final ColumnFamilyStore cfs, { txn.obsoleteOriginals(); txn.finish(); - return; - } - - boolean needsCleanupFull = needsCleanup(sstable, fullRanges); - boolean needsCleanupTransient = needsCleanup(sstable, transientRanges); - //If there are no ranges for which the table needs cleanup either due to lack of intersection or lack - //of the table being repaired. - if (!needsCleanupFull && (!needsCleanupTransient || !sstable.isRepaired())) - { - logger.trace("Skipping {} for cleanup; all rows should be kept. 
Needs cleanup full ranges: {} Needs cleanup transient ranges: {} Repaired: {}", sstable, needsCleanupFull, needsCleanupTransient, sstable.isRepaired()); + logger.info("SSTable {} ([{}, {}]) does not intersect the owned ranges ({}), dropping it", sstable, sstable.first.getToken(), sstable.last.getToken(), allRanges); return; } diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionStrategyManager.java b/src/java/org/apache/cassandra/db/compaction/CompactionStrategyManager.java index b97864185965..fd4dbeb7d8ca 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionStrategyManager.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionStrategyManager.java @@ -29,7 +29,6 @@ import java.util.List; import java.util.Set; import java.util.UUID; -import java.util.concurrent.Callable; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -956,34 +955,27 @@ private void validateForCompaction(Iterable input) } } - public Collection getMaximalTasks(final int gcBefore, final boolean splitOutput) + public CompactionTasks getMaximalTasks(final int gcBefore, final boolean splitOutput) { maybeReloadDiskBoundaries(); // runWithCompactionsDisabled cancels active compactions and disables them, then we are able // to make the repaired/unrepaired strategies mark their own sstables as compacting. 
Once the // sstables are marked the compactions are re-enabled - return cfs.runWithCompactionsDisabled(new Callable>() - { - @Override - public Collection call() + return cfs.runWithCompactionsDisabled(() -> { + List tasks = new ArrayList<>(); + readLock.lock(); + try { - List tasks = new ArrayList<>(); - readLock.lock(); - try - { - for (AbstractStrategyHolder holder : holders) - { - tasks.addAll(holder.getMaximalTasks(gcBefore, splitOutput)); - } - } - finally + for (AbstractStrategyHolder holder : holders) { - readLock.unlock(); + tasks.addAll(holder.getMaximalTasks(gcBefore, splitOutput)); } - if (tasks.isEmpty()) - return null; - return tasks; } + finally + { + readLock.unlock(); + } + return CompactionTasks.create(tasks); }, false, false); } @@ -996,7 +988,7 @@ public Collection call() * @param gcBefore gc grace period, throw away tombstones older than this * @return a list of compaction tasks corresponding to the sstables requested */ - public List getUserDefinedTasks(Collection sstables, int gcBefore) + public CompactionTasks getUserDefinedTasks(Collection sstables, int gcBefore) { maybeReloadDiskBoundaries(); List ret = new ArrayList<>(); @@ -1008,7 +1000,7 @@ public List getUserDefinedTasks(Collection implements AutoCloseable +{ + @SuppressWarnings("resource") + private static final CompactionTasks EMPTY = new CompactionTasks(Collections.emptyList()); + + private final Collection tasks; + + private CompactionTasks(Collection tasks) + { + this.tasks = tasks; + } + + public static CompactionTasks create(Collection tasks) + { + if (tasks == null || tasks.isEmpty()) + return EMPTY; + return new CompactionTasks(tasks); + } + + public static CompactionTasks empty() + { + return EMPTY; + } + + public Iterator iterator() + { + return tasks.iterator(); + } + + public int size() + { + return tasks.size(); + } + + public void close() + { + try + { + FBUtilities.closeAll(tasks.stream().map(task -> task.transaction).collect(Collectors.toList())); + } + catch 
(Exception e) + { + throw new RuntimeException(e); + } + } +} diff --git a/src/java/org/apache/cassandra/db/context/CounterContext.java b/src/java/org/apache/cassandra/db/context/CounterContext.java index 29dc3f087cd4..01c2f1d65dd0 100644 --- a/src/java/org/apache/cassandra/db/context/CounterContext.java +++ b/src/java/org/apache/cassandra/db/context/CounterContext.java @@ -629,7 +629,7 @@ public ByteBuffer markLocalToBeCleared(ByteBuffer context) ByteBuffer marked = ByteBuffer.allocate(context.remaining()); marked.putShort(marked.position(), (short) (count * -1)); - ByteBufferUtil.arrayCopy(context, + ByteBufferUtil.copyBytes(context, context.position() + HEADER_SIZE_LENGTH, marked, marked.position() + HEADER_SIZE_LENGTH, @@ -668,7 +668,7 @@ public ByteBuffer clearAllLocal(ByteBuffer context) cleared.putShort(cleared.position() + HEADER_SIZE_LENGTH + i * HEADER_ELT_LENGTH, globalShardIndexes.get(i)); int origHeaderLength = headerLength(context); - ByteBufferUtil.arrayCopy(context, + ByteBufferUtil.copyBytes(context, context.position() + origHeaderLength, cleared, cleared.position() + headerLength(cleared), diff --git a/src/java/org/apache/cassandra/db/filter/ColumnSubselection.java b/src/java/org/apache/cassandra/db/filter/ColumnSubselection.java index ddc7b1c2cc21..d0cc5143a5df 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnSubselection.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnSubselection.java @@ -26,6 +26,7 @@ import org.apache.cassandra.db.marshal.CollectionType; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.CellPath; +import org.apache.cassandra.exceptions.UnknownColumnException; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.schema.ColumnMetadata; @@ -196,7 +197,7 @@ public ColumnSubselection deserialize(DataInputPlus in, int version, TableMetada // deserialization. The column will be ignore later on anyway. 
column = metadata.getDroppedColumn(name); if (column == null) - throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); + throw new UnknownColumnException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); } Kind kind = Kind.values()[in.readUnsignedByte()]; diff --git a/src/java/org/apache/cassandra/db/filter/DataLimits.java b/src/java/org/apache/cassandra/db/filter/DataLimits.java index 50baf2170add..3a766e0c2aa8 100644 --- a/src/java/org/apache/cassandra/db/filter/DataLimits.java +++ b/src/java/org/apache/cassandra/db/filter/DataLimits.java @@ -69,6 +69,12 @@ public UnfilteredRowIterator filter(UnfilteredRowIterator iter, { return iter; } + + @Override + public PartitionIterator filter(PartitionIterator iter, int nowInSec, boolean countPartitionsWithOnlyStaticData, boolean enforceStrictLiveness) + { + return iter; + } }; // We currently deal with distinct queries by querying full partitions but limiting the result at 1 row per @@ -440,7 +446,7 @@ public float estimateTotalResults(ColumnFamilyStore cfs) { // TODO: we should start storing stats on the number of rows (instead of the number of cells, which // is what getMeanColumns returns) - float rowsPerPartition = ((float) cfs.getMeanColumns()) / cfs.metadata().regularColumns().size(); + float rowsPerPartition = ((float) cfs.getMeanEstimatedCellPerPartitionCount()) / cfs.metadata().regularColumns().size(); return rowsPerPartition * (cfs.estimateKeys()); } diff --git a/src/java/org/apache/cassandra/db/marshal/CompositeType.java b/src/java/org/apache/cassandra/db/marshal/CompositeType.java index ac4c69f7a6d2..e3423ff58008 100644 --- a/src/java/org/apache/cassandra/db/marshal/CompositeType.java +++ b/src/java/org/apache/cassandra/db/marshal/CompositeType.java @@ -360,7 +360,7 @@ public static ByteBuffer build(boolean isStatic, ByteBuffer... 
buffers) { ByteBufferUtil.writeShortLength(out, bb.remaining()); int toCopy = bb.remaining(); - ByteBufferUtil.arrayCopy(bb, bb.position(), out, out.position(), toCopy); + ByteBufferUtil.copyBytes(bb, bb.position(), out, out.position(), toCopy); out.position(out.position() + toCopy); out.put((byte) 0); } diff --git a/src/java/org/apache/cassandra/db/marshal/DecimalType.java b/src/java/org/apache/cassandra/db/marshal/DecimalType.java index b98bf009cdfc..110dc0e924e6 100644 --- a/src/java/org/apache/cassandra/db/marshal/DecimalType.java +++ b/src/java/org/apache/cassandra/db/marshal/DecimalType.java @@ -20,6 +20,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.math.MathContext; +import java.math.RoundingMode; import java.nio.ByteBuffer; import org.apache.cassandra.cql3.CQL3Type; @@ -34,6 +35,10 @@ public class DecimalType extends NumberType { public static final DecimalType instance = new DecimalType(); + private static final int MIN_SCALE = 32; + private static final int MIN_SIGNIFICANT_DIGITS = MIN_SCALE; + private static final int MAX_SCALE = 1000; + private static final MathContext MAX_PRECISION = new MathContext(10000); DecimalType() {super(ComparisonType.CUSTOM);} // singleton @@ -142,27 +147,41 @@ protected BigDecimal toBigDecimal(ByteBuffer value) public ByteBuffer add(NumberType leftType, ByteBuffer left, NumberType rightType, ByteBuffer right) { - return decompose(leftType.toBigDecimal(left).add(rightType.toBigDecimal(right), MathContext.DECIMAL128)); + return decompose(leftType.toBigDecimal(left).add(rightType.toBigDecimal(right), MAX_PRECISION)); } public ByteBuffer substract(NumberType leftType, ByteBuffer left, NumberType rightType, ByteBuffer right) { - return decompose(leftType.toBigDecimal(left).subtract(rightType.toBigDecimal(right), MathContext.DECIMAL128)); + return decompose(leftType.toBigDecimal(left).subtract(rightType.toBigDecimal(right), MAX_PRECISION)); } public ByteBuffer multiply(NumberType leftType, ByteBuffer 
left, NumberType rightType, ByteBuffer right) { - return decompose(leftType.toBigDecimal(left).multiply(rightType.toBigDecimal(right), MathContext.DECIMAL128)); + return decompose(leftType.toBigDecimal(left).multiply(rightType.toBigDecimal(right), MAX_PRECISION)); } public ByteBuffer divide(NumberType leftType, ByteBuffer left, NumberType rightType, ByteBuffer right) { - return decompose(leftType.toBigDecimal(left).divide(rightType.toBigDecimal(right), MathContext.DECIMAL128)); + BigDecimal leftOperand = leftType.toBigDecimal(left); + BigDecimal rightOperand = rightType.toBigDecimal(right); + + // Predict position of first significant digit in the quotient. + // Note: it is possible to improve prediction accuracy by comparing first significant digits in operands + // but it requires additional computations so this step is omitted + int quotientFirstDigitPos = (leftOperand.precision() - leftOperand.scale()) - (rightOperand.precision() - rightOperand.scale()); + + int scale = MIN_SIGNIFICANT_DIGITS - quotientFirstDigitPos; + scale = Math.max(scale, leftOperand.scale()); + scale = Math.max(scale, rightOperand.scale()); + scale = Math.max(scale, MIN_SCALE); + scale = Math.min(scale, MAX_SCALE); + + return decompose(leftOperand.divide(rightOperand, scale, RoundingMode.HALF_UP).stripTrailingZeros()); } public ByteBuffer mod(NumberType leftType, ByteBuffer left, NumberType rightType, ByteBuffer right) { - return decompose(leftType.toBigDecimal(left).remainder(rightType.toBigDecimal(right), MathContext.DECIMAL128)); + return decompose(leftType.toBigDecimal(left).remainder(rightType.toBigDecimal(right))); } public ByteBuffer negate(ByteBuffer input) diff --git a/src/java/org/apache/cassandra/db/monitoring/ApproximateTime.java b/src/java/org/apache/cassandra/db/monitoring/ApproximateTime.java deleted file mode 100644 index cc4b41041ce3..000000000000 --- a/src/java/org/apache/cassandra/db/monitoring/ApproximateTime.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.db.monitoring; - -import java.util.concurrent.TimeUnit; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.concurrent.ScheduledExecutors; -import org.apache.cassandra.config.Config; - -/** - * This is an approximation of System.currentTimeInMillis(). It updates its - * time value at periodic intervals of CHECK_INTERVAL_MS milliseconds - * (currently 10 milliseconds by default). It can be used as a faster alternative - * to System.currentTimeInMillis() every time an imprecision of a few milliseconds - * can be accepted. 
- */ -public class ApproximateTime -{ - private static final Logger logger = LoggerFactory.getLogger(ApproximateTime.class); - private static final int CHECK_INTERVAL_MS = Math.max(5, Integer.parseInt(System.getProperty(Config.PROPERTY_PREFIX + "approximate_time_precision_ms", "10"))); - - private static volatile long time = System.currentTimeMillis(); - static - { - logger.info("Scheduling approximate time-check task with a precision of {} milliseconds", CHECK_INTERVAL_MS); - ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(() -> time = System.currentTimeMillis(), - CHECK_INTERVAL_MS, - CHECK_INTERVAL_MS, - TimeUnit.MILLISECONDS); - } - - public static long currentTimeMillis() - { - return time; - } - - public static long precision() - { - return 2 * CHECK_INTERVAL_MS; - } - -} diff --git a/src/java/org/apache/cassandra/db/monitoring/Monitorable.java b/src/java/org/apache/cassandra/db/monitoring/Monitorable.java index c9bf94e08f46..10bd10438aa5 100644 --- a/src/java/org/apache/cassandra/db/monitoring/Monitorable.java +++ b/src/java/org/apache/cassandra/db/monitoring/Monitorable.java @@ -21,9 +21,9 @@ public interface Monitorable { String name(); - long constructionTime(); - long timeout(); - long slowTimeout(); + long creationTimeNanos(); + long timeoutNanos(); + long slowTimeoutNanos(); boolean isInProgress(); boolean isAborted(); diff --git a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java index 48c815270498..a6e7947b23f1 100644 --- a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java +++ b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java @@ -18,13 +18,15 @@ package org.apache.cassandra.db.monitoring; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + public abstract class MonitorableImpl implements Monitorable { private MonitoringState state; private boolean isSlow; - private long constructionTime = -1; - private long 
timeout; - private long slowTimeout; + private long approxCreationTimeNanos = -1; + private long timeoutNanos; + private long slowTimeoutNanos; private boolean isCrossNode; protected MonitorableImpl() @@ -38,23 +40,23 @@ protected MonitorableImpl() * is too complex, it would require passing new parameters to all serializers * or specializing the serializers to accept these message properties. */ - public void setMonitoringTime(long constructionTime, boolean isCrossNode, long timeout, long slowTimeout) + public void setMonitoringTime(long approxCreationTimeNanos, boolean isCrossNode, long timeoutNanos, long slowTimeoutNanos) { - assert constructionTime >= 0; - this.constructionTime = constructionTime; + assert approxCreationTimeNanos >= 0; + this.approxCreationTimeNanos = approxCreationTimeNanos; this.isCrossNode = isCrossNode; - this.timeout = timeout; - this.slowTimeout = slowTimeout; + this.timeoutNanos = timeoutNanos; + this.slowTimeoutNanos = slowTimeoutNanos; } - public long constructionTime() + public long creationTimeNanos() { - return constructionTime; + return approxCreationTimeNanos; } - public long timeout() + public long timeoutNanos() { - return timeout; + return timeoutNanos; } public boolean isCrossNode() @@ -62,9 +64,9 @@ public boolean isCrossNode() return isCrossNode; } - public long slowTimeout() + public long slowTimeoutNanos() { - return slowTimeout; + return slowTimeoutNanos; } public boolean isInProgress() @@ -95,8 +97,8 @@ public boolean abort() { if (state == MonitoringState.IN_PROGRESS) { - if (constructionTime >= 0) - MonitoringTask.addFailedOperation(this, ApproximateTime.currentTimeMillis()); + if (approxCreationTimeNanos >= 0) + MonitoringTask.addFailedOperation(this, approxTime.now()); state = MonitoringState.ABORTED; return true; @@ -109,8 +111,8 @@ public boolean complete() { if (state == MonitoringState.IN_PROGRESS) { - if (isSlow && slowTimeout > 0 && constructionTime >= 0) - MonitoringTask.addSlowOperation(this, 
ApproximateTime.currentTimeMillis()); + if (isSlow && slowTimeoutNanos > 0 && approxCreationTimeNanos >= 0) + MonitoringTask.addSlowOperation(this, approxTime.now()); state = MonitoringState.COMPLETED; return true; @@ -121,15 +123,15 @@ public boolean complete() private void check() { - if (constructionTime < 0 || state != MonitoringState.IN_PROGRESS) + if (approxCreationTimeNanos < 0 || state != MonitoringState.IN_PROGRESS) return; - long elapsed = ApproximateTime.currentTimeMillis() - constructionTime; + long minElapsedNanos = (approxTime.now() - approxCreationTimeNanos) - approxTime.error(); - if (elapsed >= slowTimeout && !isSlow) + if (minElapsedNanos >= slowTimeoutNanos && !isSlow) isSlow = true; - if (elapsed >= timeout) + if (minElapsedNanos >= timeoutNanos) abort(); } } diff --git a/src/java/org/apache/cassandra/db/monitoring/MonitoringTask.java b/src/java/org/apache/cassandra/db/monitoring/MonitoringTask.java index 94260422798e..0f8555f17aa3 100644 --- a/src/java/org/apache/cassandra/db/monitoring/MonitoringTask.java +++ b/src/java/org/apache/cassandra/db/monitoring/MonitoringTask.java @@ -39,6 +39,8 @@ import org.apache.cassandra.utils.NoSpamLogger; import static java.lang.System.getProperty; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; /** * A task for monitoring in progress operations, currently only read queries, and aborting them if they time out. 
@@ -68,7 +70,7 @@ class MonitoringTask private final ScheduledFuture reportingTask; private final OperationsQueue failedOperationsQueue; private final OperationsQueue slowOperationsQueue; - private long lastLogTime; + private long approxLastLogTimeNanos; @VisibleForTesting @@ -88,10 +90,10 @@ private MonitoringTask(int reportIntervalMillis, int maxOperations) this.failedOperationsQueue = new OperationsQueue(maxOperations); this.slowOperationsQueue = new OperationsQueue(maxOperations); - this.lastLogTime = ApproximateTime.currentTimeMillis(); + this.approxLastLogTimeNanos = approxTime.now(); logger.info("Scheduling monitoring task with report interval of {} ms, max operations {}", reportIntervalMillis, maxOperations); - this.reportingTask = ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay(() -> logOperations(ApproximateTime.currentTimeMillis()), + this.reportingTask = ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay(() -> logOperations(approxTime.now()), reportIntervalMillis, reportIntervalMillis, TimeUnit.MILLISECONDS); @@ -102,14 +104,14 @@ public void cancel() reportingTask.cancel(false); } - static void addFailedOperation(Monitorable operation, long now) + static void addFailedOperation(Monitorable operation, long nowNanos) { - instance.failedOperationsQueue.offer(new FailedOperation(operation, now)); + instance.failedOperationsQueue.offer(new FailedOperation(operation, nowNanos)); } - static void addSlowOperation(Monitorable operation, long now) + static void addSlowOperation(Monitorable operation, long nowNanos) { - instance.slowOperationsQueue.offer(new SlowOperation(operation, now)); + instance.slowOperationsQueue.offer(new SlowOperation(operation, nowNanos)); } @VisibleForTesting @@ -131,27 +133,27 @@ private List getLogMessages(AggregatedOperations operations) } @VisibleForTesting - private void logOperations(long now) + private void logOperations(long approxCurrentTimeNanos) { - logSlowOperations(now); - logFailedOperations(now); + 
logSlowOperations(approxCurrentTimeNanos); + logFailedOperations(approxCurrentTimeNanos); - lastLogTime = now; + approxLastLogTimeNanos = approxCurrentTimeNanos; } @VisibleForTesting - boolean logFailedOperations(long now) + boolean logFailedOperations(long nowNanos) { AggregatedOperations failedOperations = failedOperationsQueue.popOperations(); if (!failedOperations.isEmpty()) { - long elapsed = now - lastLogTime; + long elapsedNanos = nowNanos - approxLastLogTimeNanos; noSpamLogger.warn("Some operations timed out, details available at debug level (debug.log)"); if (logger.isDebugEnabled()) logger.debug("{} operations timed out in the last {} msecs:{}{}", failedOperations.num(), - elapsed, + NANOSECONDS.toMillis(elapsedNanos), LINE_SEPARATOR, failedOperations.getLogMessage()); return true; @@ -161,18 +163,18 @@ boolean logFailedOperations(long now) } @VisibleForTesting - boolean logSlowOperations(long now) + boolean logSlowOperations(long approxCurrentTimeNanos) { AggregatedOperations slowOperations = slowOperationsQueue.popOperations(); if (!slowOperations.isEmpty()) { - long elapsed = now - lastLogTime; + long approxElapsedNanos = approxCurrentTimeNanos - approxLastLogTimeNanos; noSpamLogger.info("Some operations were slow, details available at debug level (debug.log)"); if (logger.isDebugEnabled()) logger.debug("{} operations were slow in the last {} msecs:{}{}", slowOperations.num(), - elapsed, + NANOSECONDS.toMillis(approxElapsedNanos), LINE_SEPARATOR, slowOperations.getLogMessage()); return true; @@ -314,7 +316,7 @@ protected abstract static class Operation int numTimesReported; /** The total time spent by this operation */ - long totalTime; + long totalTimeNanos; /** The maximum time spent by this operation */ long maxTime; @@ -326,13 +328,13 @@ protected abstract static class Operation * this is set lazily as it takes time to build the query CQL */ private String name; - Operation(Monitorable operation, long failedAt) + Operation(Monitorable operation, 
long failedAtNanos) { this.operation = operation; numTimesReported = 1; - totalTime = failedAt - operation.constructionTime(); - minTime = totalTime; - maxTime = totalTime; + totalTimeNanos = failedAtNanos - operation.creationTimeNanos(); + minTime = totalTimeNanos; + maxTime = totalTimeNanos; } public String name() @@ -345,7 +347,7 @@ public String name() void add(Operation operation) { numTimesReported++; - totalTime += operation.totalTime; + totalTimeNanos += operation.totalTimeNanos; maxTime = Math.max(maxTime, operation.maxTime); minTime = Math.min(minTime, operation.minTime); } @@ -358,9 +360,9 @@ void add(Operation operation) */ private final static class FailedOperation extends Operation { - FailedOperation(Monitorable operation, long failedAt) + FailedOperation(Monitorable operation, long failedAtNanos) { - super(operation, failedAt); + super(operation, failedAtNanos); } public String getLogMessage() @@ -368,17 +370,17 @@ public String getLogMessage() if (numTimesReported == 1) return String.format("<%s>, total time %d msec, timeout %d %s", name(), - totalTime, - operation.timeout(), + NANOSECONDS.toMillis(totalTimeNanos), + NANOSECONDS.toMillis(operation.timeoutNanos()), operation.isCrossNode() ? "msec/cross-node" : "msec"); else return String.format("<%s> timed out %d times, avg/min/max %d/%d/%d msec, timeout %d %s", name(), numTimesReported, - totalTime / numTimesReported, - minTime, - maxTime, - operation.timeout(), + NANOSECONDS.toMillis(totalTimeNanos / numTimesReported), + NANOSECONDS.toMillis(minTime), + NANOSECONDS.toMillis(maxTime), + NANOSECONDS.toMillis(operation.timeoutNanos()), operation.isCrossNode() ? "msec/cross-node" : "msec"); } } @@ -398,17 +400,17 @@ public String getLogMessage() if (numTimesReported == 1) return String.format("<%s>, time %d msec - slow timeout %d %s", name(), - totalTime, - operation.slowTimeout(), + NANOSECONDS.toMillis(totalTimeNanos), + NANOSECONDS.toMillis(operation.slowTimeoutNanos()), operation.isCrossNode() ? 
"msec/cross-node" : "msec"); else return String.format("<%s>, was slow %d times: avg/min/max %d/%d/%d msec - slow timeout %d %s", name(), numTimesReported, - totalTime / numTimesReported, - minTime, - maxTime, - operation.slowTimeout(), + NANOSECONDS.toMillis(totalTimeNanos/ numTimesReported), + NANOSECONDS.toMillis(minTime), + NANOSECONDS.toMillis(maxTime), + NANOSECONDS.toMillis(operation.slowTimeoutNanos()), operation.isCrossNode() ? "msec/cross-node" : "msec"); } } diff --git a/src/java/org/apache/cassandra/db/partitions/PartitionIterators.java b/src/java/org/apache/cassandra/db/partitions/PartitionIterators.java index bed0958b8e2f..74a61d6feadd 100644 --- a/src/java/org/apache/cassandra/db/partitions/PartitionIterators.java +++ b/src/java/org/apache/cassandra/db/partitions/PartitionIterators.java @@ -20,6 +20,7 @@ import java.util.*; import org.apache.cassandra.db.EmptyIterators; +import org.apache.cassandra.db.transform.FilteredPartitions; import org.apache.cassandra.db.transform.MorePartitions; import org.apache.cassandra.db.transform.Transformation; import org.apache.cassandra.utils.AbstractIterator; @@ -66,7 +67,7 @@ public static PartitionIterator concat(final List iterators) class Extend implements MorePartitions { - int i = 1; + int i = 0; public PartitionIterator moreContents() { if (i >= iterators.size()) @@ -74,7 +75,8 @@ public PartitionIterator moreContents() return iterators.get(i++); } } - return MorePartitions.extend(iterators.get(0), new Extend()); + + return MorePartitions.extend(EmptyIterators.partition(), new Extend()); } public static PartitionIterator singletonIterator(RowIterator iterator) diff --git a/src/java/org/apache/cassandra/db/repair/PendingAntiCompaction.java b/src/java/org/apache/cassandra/db/repair/PendingAntiCompaction.java index fac164d006a5..85d262566290 100644 --- a/src/java/org/apache/cassandra/db/repair/PendingAntiCompaction.java +++ b/src/java/org/apache/cassandra/db/repair/PendingAntiCompaction.java @@ -349,7 +349,7 @@ 
public ListenableFuture run() List> tasks = new ArrayList<>(tables.size()); for (ColumnFamilyStore cfs : tables) { - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ListenableFutureTask task = ListenableFutureTask.create(getAcquisitionCallable(cfs, tokenRanges.ranges(), prsId, acquireRetrySeconds, acquireSleepMillis)); executor.submit(task); tasks.add(task); diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamReader.java b/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamReader.java index eb993ff0f8d1..c362d1174329 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamReader.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamReader.java @@ -99,8 +99,9 @@ public SSTableMultiWriter read(DataInputPlus inputPlus) throws IOException // when compressed, report total bytes of compressed chunks read since remoteFile.size is the sum of chunks transferred session.progress(filename, ProgressInfo.Direction.IN, cis.getTotalCompressedBytesRead(), totalSize); } + assert in.getBytesRead() == sectionLength; } - logger.debug("[Stream #{}] Finished receiving file #{} from {} readBytes = {}, totalSize = {}", session.planId(), fileSeqNum, + logger.trace("[Stream #{}] Finished receiving file #{} from {} readBytes = {}, totalSize = {}", session.planId(), fileSeqNum, session.peer, FBUtilities.prettyPrintMemory(cis.getTotalCompressedBytesRead()), FBUtilities.prettyPrintMemory(totalSize)); return writer; } diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamWriter.java b/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamWriter.java index 3b971f885942..efbccdcf25b3 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamWriter.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraCompressedStreamWriter.java @@ -31,8 +31,7 @@ import org.apache.cassandra.io.sstable.format.SSTableReader; import 
org.apache.cassandra.io.util.ChannelProxy; import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.io.util.FileUtils; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.streaming.ProgressInfo; import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.utils.FBUtilities; @@ -55,10 +54,9 @@ public CassandraCompressedStreamWriter(SSTableReader sstable, Collection { + ByteBuffer outBuffer = bufferSupplier.get(toTransfer); + long read = fc.read(outBuffer, position); + assert read == toTransfer : String.format("could not read required number of bytes from file to be streamed: read %d bytes, wanted %d bytes", read, toTransfer); outBuffer.flip(); - output.writeToChannel(outBuffer); - } - catch (IOException e) - { - FileUtils.clean(outBuffer); - throw e; - } - - bytesTransferred += lastWrite; - progress += lastWrite; + }, limiter); + + bytesTransferred += toTransfer; + progress += toTransfer; session.progress(sstable.descriptor.filenameFor(Component.DATA), ProgressInfo.Direction.OUT, progress, totalSize); } } diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriter.java b/src/java/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriter.java index 7a20110d0bf3..401b20ed56ad 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriter.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriter.java @@ -27,7 +27,7 @@ import org.apache.cassandra.io.sstable.Component; import org.apache.cassandra.io.sstable.format.SSTableReader; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.streaming.ProgressInfo; import org.apache.cassandra.streaming.StreamManager; import org.apache.cassandra.streaming.StreamSession; @@ -58,11 
+58,11 @@ public CassandraEntireSSTableStreamWriter(SSTableReader sstable, StreamSession s /** * Stream the entire file to given channel. *

- * + * TODO: this currently requires a companion thread, but could be performed entirely asynchronously * @param out where this writes data to * @throws IOException on any I/O error */ - public void write(ByteBufDataOutputStreamPlus out) throws IOException + public void write(AsyncStreamingOutputPlus out) throws IOException { long totalSize = manifest.totalSize(); logger.debug("[Stream #{}] Start streaming sstable {} to {}, repairedAt = {}, totalSize = {}", @@ -76,7 +76,7 @@ public void write(ByteBufDataOutputStreamPlus out) throws IOException for (Component component : manifest.components()) { - @SuppressWarnings("resource") // this is closed after the file is transferred by ByteBufDataOutputStreamPlus + @SuppressWarnings("resource") // this is closed after the file is transferred by AsyncChannelOutputPlus FileChannel in = new RandomAccessFile(sstable.descriptor.filenameFor(component), "r").getChannel(); // Total Length to transmit for this file @@ -90,7 +90,7 @@ public void write(ByteBufDataOutputStreamPlus out) throws IOException component, prettyPrintMemory(length)); - long bytesWritten = out.writeToChannel(in, limiter); + long bytesWritten = out.writeFileToChannel(in, limiter); progress += bytesWritten; session.progress(sstable.descriptor.filenameFor(component), ProgressInfo.Direction.OUT, bytesWritten, length); diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraOutgoingFile.java b/src/java/org/apache/cassandra/db/streaming/CassandraOutgoingFile.java index c688fdf7f3e4..e8f5485844c5 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraOutgoingFile.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraOutgoingFile.java @@ -20,7 +20,6 @@ import java.io.File; import java.io.IOException; -import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Objects; @@ -41,7 +40,7 @@ import org.apache.cassandra.io.sstable.KeyIterator; import org.apache.cassandra.io.sstable.format.SSTableReader; 
import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.streaming.OutgoingStream; import org.apache.cassandra.streaming.StreamOperation; @@ -164,10 +163,10 @@ public void write(StreamSession session, DataOutputStreamPlus out, int version) CassandraStreamHeader.serializer.serialize(header, out, version); out.flush(); - if (shouldStreamEntireSSTable() && out instanceof ByteBufDataOutputStreamPlus) + if (shouldStreamEntireSSTable() && out instanceof AsyncStreamingOutputPlus) { CassandraEntireSSTableStreamWriter writer = new CassandraEntireSSTableStreamWriter(sstable, session, manifest); - writer.write((ByteBufDataOutputStreamPlus) out); + writer.write((AsyncStreamingOutputPlus) out); } else { diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraStreamReader.java b/src/java/org/apache/cassandra/db/streaming/CassandraStreamReader.java index 43371a95cd90..190f1360bbc2 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraStreamReader.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraStreamReader.java @@ -28,7 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.cassandra.db.lifecycle.LifecycleTransaction; +import org.apache.cassandra.exceptions.UnknownColumnException; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.TrackedDataInputPlus; import org.apache.cassandra.schema.TableId; @@ -47,10 +47,11 @@ import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.streaming.compress.StreamCompressionInputStream; import org.apache.cassandra.streaming.messages.StreamMessageHeader; -import org.apache.cassandra.streaming.messages.StreamMessage; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; +import 
static org.apache.cassandra.net.MessagingService.current_version; + /** * CassandraStreamReader reads from stream and writes to SSTable. */ @@ -114,7 +115,7 @@ public SSTableMultiWriter read(DataInputPlus inputPlus) throws IOException StreamDeserializer deserializer = null; SSTableMultiWriter writer = null; - try (StreamCompressionInputStream streamCompressionInputStream = new StreamCompressionInputStream(inputPlus, StreamMessage.CURRENT_VERSION)) + try (StreamCompressionInputStream streamCompressionInputStream = new StreamCompressionInputStream(inputPlus, current_version)) { TrackedDataInputPlus in = new TrackedDataInputPlus(streamCompressionInputStream); deserializer = new StreamDeserializer(cfs.metadata(), in, inputVersion, getHeader(cfs.metadata())); @@ -142,7 +143,7 @@ public SSTableMultiWriter read(DataInputPlus inputPlus) throws IOException } } - protected SerializationHeader getHeader(TableMetadata metadata) + protected SerializationHeader getHeader(TableMetadata metadata) throws UnknownColumnException { return header != null? header.toHeader(metadata) : null; //pre-3.0 sstable have no SerializationHeader } diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraStreamReceiver.java b/src/java/org/apache/cassandra/db/streaming/CassandraStreamReceiver.java index b2b2ce5cf093..8338a178361a 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraStreamReceiver.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraStreamReceiver.java @@ -28,7 +28,7 @@ import org.apache.cassandra.db.lifecycle.LifecycleNewTracker; import org.apache.cassandra.io.sstable.SSTable; -import org.apache.cassandra.streaming.StreamReceiveTask; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -273,7 +273,7 @@ public void cleanup() // the streamed sstables. 
if (requiresWritePath) { - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); abort(); } } diff --git a/src/java/org/apache/cassandra/db/streaming/CassandraStreamWriter.java b/src/java/org/apache/cassandra/db/streaming/CassandraStreamWriter.java index c6dd9a91e6d0..ffc663dd18cb 100644 --- a/src/java/org/apache/cassandra/db/streaming/CassandraStreamWriter.java +++ b/src/java/org/apache/cassandra/db/streaming/CassandraStreamWriter.java @@ -25,19 +25,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Factory; +import org.apache.cassandra.io.compress.BufferType; import org.apache.cassandra.io.sstable.Component; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.ChannelProxy; import org.apache.cassandra.io.util.DataIntegrityMetadata; import org.apache.cassandra.io.util.DataIntegrityMetadata.ChecksumValidator; import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.streaming.ProgressInfo; import org.apache.cassandra.streaming.StreamManager; import org.apache.cassandra.streaming.StreamManager.StreamRateLimiter; import org.apache.cassandra.streaming.StreamSession; -import org.apache.cassandra.streaming.compress.ByteBufCompressionDataOutputStreamPlus; +import org.apache.cassandra.streaming.async.StreamCompressionSerializer; import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.memory.BufferPool; + +import static org.apache.cassandra.net.MessagingService.current_version; /** * CassandraStreamWriter writes given section of the SSTable to given channel. 
@@ -49,6 +55,7 @@ public class CassandraStreamWriter private static final Logger logger = LoggerFactory.getLogger(CassandraStreamWriter.class); protected final SSTableReader sstable; + private final LZ4Compressor compressor = LZ4Factory.fastestInstance().fastCompressor(); protected final Collection sections; protected final StreamRateLimiter limiter; protected final StreamSession session; @@ -75,6 +82,7 @@ public void write(DataOutputStreamPlus output) throws IOException logger.debug("[Stream #{}] Start streaming file {} to {}, repairedAt = {}, totalSize = {}", session.planId(), sstable.getFilename(), session.peer, sstable.getSSTableMetadata().repairedAt, totalSize); + AsyncStreamingOutputPlus out = (AsyncStreamingOutputPlus) output; try(ChannelProxy proxy = sstable.getDataChannel().sharedCopy(); ChecksumValidator validator = new File(sstable.descriptor.filenameFor(Component.CRC)).exists() ? DataIntegrityMetadata.checksumValidator(sstable.descriptor) @@ -85,38 +93,35 @@ public void write(DataOutputStreamPlus output) throws IOException // setting up data compression stream long progress = 0L; - try (DataOutputStreamPlus compressedOutput = new ByteBufCompressionDataOutputStreamPlus(output, limiter)) + // stream each of the required sections of the file + for (SSTableReader.PartitionPositionBounds section : sections) { - // stream each of the required sections of the file - for (SSTableReader.PartitionPositionBounds section : sections) + long start = validator == null ? section.lowerPosition : validator.chunkStart(section.lowerPosition); + // if the transfer does not start on the valididator's chunk boundary, this is the number of bytes to offset by + int transferOffset = (int) (section.lowerPosition - start); + if (validator != null) + validator.seek(start); + + // length of the section to read + long length = section.upperPosition - start; + // tracks write progress + long bytesRead = 0; + while (bytesRead < length) { - long start = validator == null ? 
section.lowerPosition : validator.chunkStart(section.lowerPosition); - // if the transfer does not start on the valididator's chunk boundary, this is the number of bytes to offset by - int transferOffset = (int) (section.lowerPosition - start); - if (validator != null) - validator.seek(start); - - // length of the section to read - long length = section.upperPosition - start; - // tracks write progress - long bytesRead = 0; - while (bytesRead < length) - { - int toTransfer = (int) Math.min(bufferSize, length - bytesRead); - long lastBytesRead = write(proxy, validator, compressedOutput, start, transferOffset, toTransfer, bufferSize); - start += lastBytesRead; - bytesRead += lastBytesRead; - progress += (lastBytesRead - transferOffset); - session.progress(sstable.descriptor.filenameFor(Component.DATA), ProgressInfo.Direction.OUT, progress, totalSize); - transferOffset = 0; - } - - // make sure that current section is sent - output.flush(); + int toTransfer = (int) Math.min(bufferSize, length - bytesRead); + long lastBytesRead = write(proxy, validator, out, start, transferOffset, toTransfer, bufferSize); + start += lastBytesRead; + bytesRead += lastBytesRead; + progress += (lastBytesRead - transferOffset); + session.progress(sstable.descriptor.filenameFor(Component.DATA), ProgressInfo.Direction.OUT, progress, totalSize); + transferOffset = 0; } - logger.debug("[Stream #{}] Finished streaming file {} to {}, bytesTransferred = {}, totalSize = {}", - session.planId(), sstable.getFilename(), session.peer, FBUtilities.prettyPrintMemory(progress), FBUtilities.prettyPrintMemory(totalSize)); + + // make sure that current section is sent + out.flush(); } + logger.debug("[Stream #{}] Finished streaming file {} to {}, bytesTransferred = {}, totalSize = {}", + session.planId(), sstable.getFilename(), session.peer, FBUtilities.prettyPrintMemory(progress), FBUtilities.prettyPrintMemory(totalSize)); } } @@ -141,14 +146,14 @@ protected long totalSize() * * @throws java.io.IOException 
on any I/O error */ - protected long write(ChannelProxy proxy, ChecksumValidator validator, DataOutputStreamPlus output, long start, int transferOffset, int toTransfer, int bufferSize) throws IOException + protected long write(ChannelProxy proxy, ChecksumValidator validator, AsyncStreamingOutputPlus output, long start, int transferOffset, int toTransfer, int bufferSize) throws IOException { // the count of bytes to read off disk int minReadable = (int) Math.min(bufferSize, proxy.size() - start); // this buffer will hold the data from disk. as it will be compressed on the fly by - // ByteBufCompressionDataOutputStreamPlus.write(ByteBuffer), we can release this buffer as soon as we can. - ByteBuffer buffer = ByteBuffer.allocateDirect(minReadable); + // AsyncChannelCompressedStreamWriter.write(ByteBuffer), we can release this buffer as soon as we can. + ByteBuffer buffer = BufferPool.get(minReadable, BufferType.OFF_HEAP); try { int readCount = proxy.read(buffer, start); @@ -163,11 +168,11 @@ protected long write(ChannelProxy proxy, ChecksumValidator validator, DataOutput buffer.position(transferOffset); buffer.limit(transferOffset + (toTransfer - transferOffset)); - output.write(buffer); + output.writeToChannel(StreamCompressionSerializer.serialize(compressor, buffer, current_version), limiter); } finally { - FileUtils.clean(buffer); + BufferPool.put(buffer); } return toTransfer; diff --git a/src/java/org/apache/cassandra/db/transform/BasePartitions.java b/src/java/org/apache/cassandra/db/transform/BasePartitions.java index 7b0f56c40293..464ae6f88bc9 100644 --- a/src/java/org/apache/cassandra/db/transform/BasePartitions.java +++ b/src/java/org/apache/cassandra/db/transform/BasePartitions.java @@ -90,7 +90,7 @@ public final boolean hasNext() Transformation[] fs = stack; int len = length; - while (!stop.isSignalled && input.hasNext()) + while (!stop.isSignalled && !stopChild.isSignalled && input.hasNext()) { next = input.next(); for (int i = 0 ; next != null & i < len ; 
i++) @@ -103,7 +103,7 @@ public final boolean hasNext() } } - if (stop.isSignalled || stopChild.isSignalled || !hasMoreContents()) + if (stop.isSignalled || !hasMoreContents()) return false; } return true; diff --git a/src/java/org/apache/cassandra/db/view/TableViews.java b/src/java/org/apache/cassandra/db/view/TableViews.java index 09490e8bd7ec..9d4d997a774b 100644 --- a/src/java/org/apache/cassandra/db/view/TableViews.java +++ b/src/java/org/apache/cassandra/db/view/TableViews.java @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; @@ -96,7 +95,7 @@ public Iterable allViewsCfs() public void forceBlockingFlush() { for (ColumnFamilyStore viewCfs : allViewsCfs()) - viewCfs.forceBlockingFlush(); + viewCfs.forceBlockingFlushToSSTable(); } public void dumpMemtables() diff --git a/src/java/org/apache/cassandra/db/view/ViewBuilder.java b/src/java/org/apache/cassandra/db/view/ViewBuilder.java index 67172973ee82..6d92c73cd739 100644 --- a/src/java/org/apache/cassandra/db/view/ViewBuilder.java +++ b/src/java/org/apache/cassandra/db/view/ViewBuilder.java @@ -96,7 +96,7 @@ public void start() logger.debug("Starting build of view({}.{}). 
Flushing base table {}.{}", ksName, view.name, ksName, baseCfs.name); - baseCfs.forceBlockingFlush(); + baseCfs.forceBlockingFlushToSSTable(); loadStatusAndBuild(); } diff --git a/src/java/org/apache/cassandra/db/virtual/AbstractVirtualTable.java b/src/java/org/apache/cassandra/db/virtual/AbstractVirtualTable.java index 2998b779e263..6c49b9a1b00a 100644 --- a/src/java/org/apache/cassandra/db/virtual/AbstractVirtualTable.java +++ b/src/java/org/apache/cassandra/db/virtual/AbstractVirtualTable.java @@ -42,7 +42,7 @@ */ public abstract class AbstractVirtualTable implements VirtualTable { - private final TableMetadata metadata; + protected final TableMetadata metadata; protected AbstractVirtualTable(TableMetadata metadata) { diff --git a/src/java/org/apache/cassandra/db/virtual/InternodeInboundTable.java b/src/java/org/apache/cassandra/db/virtual/InternodeInboundTable.java new file mode 100644 index 000000000000..b0afe8f699a9 --- /dev/null +++ b/src/java/org/apache/cassandra/db/virtual/InternodeInboundTable.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.db.virtual; + +import java.net.InetAddress; +import java.nio.ByteBuffer; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.marshal.CompositeType; +import org.apache.cassandra.db.marshal.InetAddressType; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.db.marshal.LongType; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.dht.LocalPartitioner; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.InboundMessageHandlers; +import org.apache.cassandra.schema.TableMetadata; + +public final class InternodeInboundTable extends AbstractVirtualTable +{ + private static final String ADDRESS = "address"; + private static final String PORT = "port"; + private static final String DC = "dc"; + private static final String RACK = "rack"; + + private static final String USING_BYTES = "using_bytes"; + private static final String USING_RESERVE_BYTES = "using_reserve_bytes"; + private static final String CORRUPT_FRAMES_RECOVERED = "corrupt_frames_recovered"; + private static final String CORRUPT_FRAMES_UNRECOVERED = "corrupt_frames_unrecovered"; + private static final String ERROR_BYTES = "error_bytes"; + private static final String ERROR_COUNT = "error_count"; + private static final String EXPIRED_BYTES = "expired_bytes"; + private static final String EXPIRED_COUNT = "expired_count"; + private static final String SCHEDULED_BYTES = "scheduled_bytes"; + private static final String SCHEDULED_COUNT = "scheduled_count"; + private static final String PROCESSED_BYTES = "processed_bytes"; + private static final String PROCESSED_COUNT = "processed_count"; + private static final String RECEIVED_BYTES = "received_bytes"; + private static final String RECEIVED_COUNT = "received_count"; + private static final String THROTTLED_COUNT = 
"throttled_count"; + private static final String THROTTLED_NANOS = "throttled_nanos"; + + InternodeInboundTable(String keyspace) + { + super(TableMetadata.builder(keyspace, "internode_inbound") + .kind(TableMetadata.Kind.VIRTUAL) + .partitioner(new LocalPartitioner(CompositeType.getInstance(InetAddressType.instance, Int32Type.instance))) + .addPartitionKeyColumn(ADDRESS, InetAddressType.instance) + .addPartitionKeyColumn(PORT, Int32Type.instance) + .addClusteringColumn(DC, UTF8Type.instance) + .addClusteringColumn(RACK, UTF8Type.instance) + .addRegularColumn(USING_BYTES, LongType.instance) + .addRegularColumn(USING_RESERVE_BYTES, LongType.instance) + .addRegularColumn(CORRUPT_FRAMES_RECOVERED, LongType.instance) + .addRegularColumn(CORRUPT_FRAMES_UNRECOVERED, LongType.instance) + .addRegularColumn(ERROR_BYTES, LongType.instance) + .addRegularColumn(ERROR_COUNT, LongType.instance) + .addRegularColumn(EXPIRED_BYTES, LongType.instance) + .addRegularColumn(EXPIRED_COUNT, LongType.instance) + .addRegularColumn(SCHEDULED_BYTES, LongType.instance) + .addRegularColumn(SCHEDULED_COUNT, LongType.instance) + .addRegularColumn(PROCESSED_BYTES, LongType.instance) + .addRegularColumn(PROCESSED_COUNT, LongType.instance) + .addRegularColumn(RECEIVED_BYTES, LongType.instance) + .addRegularColumn(RECEIVED_COUNT, LongType.instance) + .addRegularColumn(THROTTLED_COUNT, LongType.instance) + .addRegularColumn(THROTTLED_NANOS, LongType.instance) + .build()); + } + + @Override + public DataSet data(DecoratedKey partitionKey) + { + ByteBuffer[] addressAndPortBytes = ((CompositeType) metadata().partitionKeyType).split(partitionKey.getKey()); + InetAddress address = InetAddressType.instance.compose(addressAndPortBytes[0]); + int port = Int32Type.instance.compose(addressAndPortBytes[1]); + InetAddressAndPort addressAndPort = InetAddressAndPort.getByAddressOverrideDefaults(address, port); + + SimpleDataSet result = new SimpleDataSet(metadata()); + InboundMessageHandlers handlers = 
MessagingService.instance().messageHandlers.get(addressAndPort); + if (null != handlers) + addRow(result, addressAndPort, handlers); + return result; + } + + @Override + public DataSet data() + { + SimpleDataSet result = new SimpleDataSet(metadata()); + MessagingService.instance() + .messageHandlers + .forEach((addressAndPort, handlers) -> addRow(result, addressAndPort, handlers)); + return result; + } + + private void addRow(SimpleDataSet dataSet, InetAddressAndPort addressAndPort, InboundMessageHandlers handlers) + { + String dc = DatabaseDescriptor.getEndpointSnitch().getDatacenter(addressAndPort); + String rack = DatabaseDescriptor.getEndpointSnitch().getRack(addressAndPort); + dataSet.row(addressAndPort.address, addressAndPort.port, dc, rack) + .column(USING_BYTES, handlers.usingCapacity()) + .column(USING_RESERVE_BYTES, handlers.usingEndpointReserveCapacity()) + .column(CORRUPT_FRAMES_RECOVERED, handlers.corruptFramesRecovered()) + .column(CORRUPT_FRAMES_UNRECOVERED, handlers.corruptFramesUnrecovered()) + .column(ERROR_BYTES, handlers.errorBytes()) + .column(ERROR_COUNT, handlers.errorCount()) + .column(EXPIRED_BYTES, handlers.expiredBytes()) + .column(EXPIRED_COUNT, handlers.expiredCount()) + .column(SCHEDULED_BYTES, handlers.scheduledBytes()) + .column(SCHEDULED_COUNT, handlers.scheduledCount()) + .column(PROCESSED_BYTES, handlers.processedBytes()) + .column(PROCESSED_COUNT, handlers.processedCount()) + .column(RECEIVED_BYTES, handlers.receivedBytes()) + .column(RECEIVED_COUNT, handlers.receivedCount()) + .column(THROTTLED_COUNT, handlers.throttledCount()) + .column(THROTTLED_NANOS, handlers.throttledNanos()); + } +} diff --git a/src/java/org/apache/cassandra/db/virtual/InternodeOutboundTable.java b/src/java/org/apache/cassandra/db/virtual/InternodeOutboundTable.java new file mode 100644 index 000000000000..87b38235e586 --- /dev/null +++ b/src/java/org/apache/cassandra/db/virtual/InternodeOutboundTable.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.db.virtual; + +import java.net.InetAddress; +import java.nio.ByteBuffer; +import java.util.function.ToLongFunction; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.marshal.CompositeType; +import org.apache.cassandra.db.marshal.InetAddressType; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.db.marshal.LongType; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.dht.LocalPartitioner; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.OutboundConnection; +import org.apache.cassandra.net.OutboundConnections; +import org.apache.cassandra.schema.TableMetadata; + +public final class InternodeOutboundTable extends AbstractVirtualTable +{ + private static final String ADDRESS = "address"; + private static final String PORT = "port"; + private static final String DC = "dc"; + private static final String RACK = "rack"; + + private static final String USING_BYTES = "using_bytes"; + private static final String USING_RESERVE_BYTES = 
"using_reserve_bytes"; + private static final String PENDING_COUNT = "pending_count"; + private static final String PENDING_BYTES = "pending_bytes"; + private static final String SENT_COUNT = "sent_count"; + private static final String SENT_BYTES = "sent_bytes"; + private static final String EXPIRED_COUNT = "expired_count"; + private static final String EXPIRED_BYTES = "expired_bytes"; + private static final String ERROR_COUNT = "error_count"; + private static final String ERROR_BYTES = "error_bytes"; + private static final String OVERLOAD_COUNT = "overload_count"; + private static final String OVERLOAD_BYTES = "overload_bytes"; + private static final String ACTIVE_CONNECTION_COUNT = "active_connections"; + private static final String CONNECTION_ATTEMPTS = "connection_attempts"; + private static final String SUCCESSFUL_CONNECTION_ATTEMPTS = "successful_connection_attempts"; + + InternodeOutboundTable(String keyspace) + { + super(TableMetadata.builder(keyspace, "internode_outbound") + .kind(TableMetadata.Kind.VIRTUAL) + .partitioner(new LocalPartitioner(CompositeType.getInstance(InetAddressType.instance, Int32Type.instance))) + .addPartitionKeyColumn(ADDRESS, InetAddressType.instance) + .addPartitionKeyColumn(PORT, Int32Type.instance) + .addClusteringColumn(DC, UTF8Type.instance) + .addClusteringColumn(RACK, UTF8Type.instance) + .addRegularColumn(USING_BYTES, LongType.instance) + .addRegularColumn(USING_RESERVE_BYTES, LongType.instance) + .addRegularColumn(PENDING_COUNT, LongType.instance) + .addRegularColumn(PENDING_BYTES, LongType.instance) + .addRegularColumn(SENT_COUNT, LongType.instance) + .addRegularColumn(SENT_BYTES, LongType.instance) + .addRegularColumn(EXPIRED_COUNT, LongType.instance) + .addRegularColumn(EXPIRED_BYTES, LongType.instance) + .addRegularColumn(ERROR_COUNT, LongType.instance) + .addRegularColumn(ERROR_BYTES, LongType.instance) + .addRegularColumn(OVERLOAD_COUNT, LongType.instance) + .addRegularColumn(OVERLOAD_BYTES, LongType.instance) + 
.addRegularColumn(ACTIVE_CONNECTION_COUNT, LongType.instance) + .addRegularColumn(CONNECTION_ATTEMPTS, LongType.instance) + .addRegularColumn(SUCCESSFUL_CONNECTION_ATTEMPTS, LongType.instance) + .build()); + } + + @Override + public DataSet data(DecoratedKey partitionKey) + { + ByteBuffer[] addressAndPortBytes = ((CompositeType) metadata().partitionKeyType).split(partitionKey.getKey()); + InetAddress address = InetAddressType.instance.compose(addressAndPortBytes[0]); + int port = Int32Type.instance.compose(addressAndPortBytes[1]); + InetAddressAndPort addressAndPort = InetAddressAndPort.getByAddressOverrideDefaults(address, port); + + SimpleDataSet result = new SimpleDataSet(metadata()); + OutboundConnections connections = MessagingService.instance().channelManagers.get(addressAndPort); + if (null != connections) + addRow(result, addressAndPort, connections); + return result; + } + + @Override + public DataSet data() + { + SimpleDataSet result = new SimpleDataSet(metadata()); + MessagingService.instance() + .channelManagers + .forEach((addressAndPort, connections) -> addRow(result, addressAndPort, connections)); + return result; + } + + private void addRow(SimpleDataSet dataSet, InetAddressAndPort addressAndPort, OutboundConnections connections) + { + String dc = DatabaseDescriptor.getEndpointSnitch().getDatacenter(addressAndPort); + String rack = DatabaseDescriptor.getEndpointSnitch().getRack(addressAndPort); + long pendingBytes = sum(connections, OutboundConnection::pendingBytes); + dataSet.row(addressAndPort.address, addressAndPort.port, dc, rack) + .column(USING_BYTES, pendingBytes) + .column(USING_RESERVE_BYTES, connections.usingReserveBytes()) + .column(PENDING_COUNT, sum(connections, OutboundConnection::pendingCount)) + .column(PENDING_BYTES, pendingBytes) + .column(SENT_COUNT, sum(connections, OutboundConnection::sentCount)) + .column(SENT_BYTES, sum(connections, OutboundConnection::sentBytes)) + .column(EXPIRED_COUNT, sum(connections, 
OutboundConnection::expiredCount)) + .column(EXPIRED_BYTES, sum(connections, OutboundConnection::expiredBytes)) + .column(ERROR_COUNT, sum(connections, OutboundConnection::errorCount)) + .column(ERROR_BYTES, sum(connections, OutboundConnection::errorBytes)) + .column(OVERLOAD_COUNT, sum(connections, OutboundConnection::overloadedCount)) + .column(OVERLOAD_BYTES, sum(connections, OutboundConnection::overloadedBytes)) + .column(ACTIVE_CONNECTION_COUNT, sum(connections, c -> c.isConnected() ? 1 : 0)) + .column(CONNECTION_ATTEMPTS, sum(connections, OutboundConnection::connectionAttempts)) + .column(SUCCESSFUL_CONNECTION_ATTEMPTS, sum(connections, OutboundConnection::successfulConnections)); + } + + private static long sum(OutboundConnections connections, ToLongFunction f) + { + return f.applyAsLong(connections.small) + f.applyAsLong(connections.large) + f.applyAsLong(connections.urgent); + } +} diff --git a/src/java/org/apache/cassandra/db/virtual/SettingsTable.java b/src/java/org/apache/cassandra/db/virtual/SettingsTable.java index 34debc6b09ee..048d4ba35da3 100644 --- a/src/java/org/apache/cassandra/db/virtual/SettingsTable.java +++ b/src/java/org/apache/cassandra/db/virtual/SettingsTable.java @@ -163,7 +163,7 @@ private void addEncryptionOptions(SimpleDataSet result, Field f) result.row(f.getName() + "_enabled").column(VALUE, Boolean.toString(value.enabled)); result.row(f.getName() + "_algorithm").column(VALUE, value.algorithm); result.row(f.getName() + "_protocol").column(VALUE, value.protocol); - result.row(f.getName() + "_cipher_suites").column(VALUE, Arrays.toString(value.cipher_suites)); + result.row(f.getName() + "_cipher_suites").column(VALUE, value.cipher_suites.toString()); result.row(f.getName() + "_client_auth").column(VALUE, Boolean.toString(value.require_client_auth)); result.row(f.getName() + "_endpoint_verification").column(VALUE, Boolean.toString(value.require_endpoint_verification)); result.row(f.getName() + "_optional").column(VALUE, 
Boolean.toString(value.optional)); diff --git a/src/java/org/apache/cassandra/db/virtual/SimpleDataSet.java b/src/java/org/apache/cassandra/db/virtual/SimpleDataSet.java index bf401401d28f..00acaedc02f0 100644 --- a/src/java/org/apache/cassandra/db/virtual/SimpleDataSet.java +++ b/src/java/org/apache/cassandra/db/virtual/SimpleDataSet.java @@ -73,6 +73,8 @@ public SimpleDataSet column(String columnName, Object value) { if (null == currentRow) throw new IllegalStateException(); + if (null == columnName) + throw new IllegalStateException(String.format("Invalid column: %s=%s for %s", columnName, value, currentRow)); currentRow.add(columnName, value); return this; } @@ -181,6 +183,11 @@ private org.apache.cassandra.db.rows.Row toTableRow(RegularAndStaticColumns colu return builder.build(); } + + public String toString() + { + return "Row[...:" + clustering.toString(metadata)+']'; + } } @SuppressWarnings("unchecked") diff --git a/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java b/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java index f85991acc159..abcdf87c26a6 100644 --- a/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java +++ b/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java @@ -27,10 +27,15 @@ public final class SystemViewsKeyspace extends VirtualKeyspace private SystemViewsKeyspace() { - super(NAME, ImmutableList.of(new CachesTable(NAME), - new ClientsTable(NAME), - new SettingsTable(NAME), - new SSTableTasksTable(NAME), - new ThreadPoolsTable(NAME))); + super(NAME, new ImmutableList.Builder() + .add(new CachesTable(NAME)) + .add(new ClientsTable(NAME)) + .add(new SettingsTable(NAME)) + .add(new SSTableTasksTable(NAME)) + .add(new ThreadPoolsTable(NAME)) + .add(new InternodeOutboundTable(NAME)) + .add(new InternodeInboundTable(NAME)) + .addAll(TableMetricTables.getAll(NAME)) + .build()); } } diff --git a/src/java/org/apache/cassandra/db/virtual/TableMetricTables.java 
b/src/java/org/apache/cassandra/db/virtual/TableMetricTables.java new file mode 100644 index 000000000000..4a043adf0765 --- /dev/null +++ b/src/java/org/apache/cassandra/db/virtual/TableMetricTables.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.virtual; + +import java.math.BigDecimal; +import java.util.Collection; +import java.util.function.Function; + +import com.google.common.collect.ImmutableList; +import org.apache.commons.math3.util.Precision; + +import com.codahale.metrics.Counting; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Metered; +import com.codahale.metrics.Metric; +import com.codahale.metrics.Sampling; +import com.codahale.metrics.Snapshot; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.marshal.CompositeType; +import org.apache.cassandra.db.marshal.DoubleType; +import org.apache.cassandra.db.marshal.LongType; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.dht.IPartitioner; +import org.apache.cassandra.dht.LocalPartitioner; +import org.apache.cassandra.metrics.TableMetrics; +import org.apache.cassandra.schema.TableMetadata; + +/** + * Contains multiple the Table Metric virtual tables. This is not a direct wrapper over the Metrics like with JMX but a + * view to the metrics so that the underlying mechanism can change but still give same appearance (like nodetool). 
+ */ +public class TableMetricTables +{ + private final static String KEYSPACE_NAME = "keyspace_name"; + private final static String TABLE_NAME = "table_name"; + private final static String P50 = "50th"; + private final static String P99 = "99th"; + private final static String MAX = "max"; + private final static String RATE = "per_second"; + private final static double BYTES_TO_MIB = 1.0 / (1024 * 1024); + private final static double NS_TO_MS = 0.000001; + + private final static AbstractType TYPE = CompositeType.getInstance(UTF8Type.instance, + UTF8Type.instance); + private final static IPartitioner PARTITIONER = new LocalPartitioner(TYPE); + + /** + * Generates all table metric tables in a collection + */ + public static Collection getAll(String name) + { + return ImmutableList.of( + new LatencyTableMetric(name, "local_read_latency", t -> t.readLatency.latency), + new LatencyTableMetric(name, "local_scan_latency", t -> t.rangeLatency.latency), + new LatencyTableMetric(name, "coordinator_read_latency", t -> t.coordinatorReadLatency), + new LatencyTableMetric(name, "coordinator_scan_latency", t -> t.coordinatorScanLatency), + new LatencyTableMetric(name, "local_write_latency", t -> t.writeLatency.latency), + new LatencyTableMetric(name, "coordinator_write_latency", t -> t.coordinatorWriteLatency), + new HistogramTableMetric(name, "tombstones_per_read", t -> t.tombstoneScannedHistogram.cf), + new HistogramTableMetric(name, "rows_per_read", t -> t.liveScannedHistogram.cf), + new StorageTableMetric(name, "disk_usage", (TableMetrics t) -> t.totalDiskSpaceUsed), + new StorageTableMetric(name, "max_partition_size", (TableMetrics t) -> t.maxPartitionSize)); + } + + /** + * A table that describes a some amount of disk on space in a Counter or Gauge + */ + private static class StorageTableMetric extends TableMetricTable + { + interface GaugeFunction extends Function> {} + interface CountingFunction extends Function {} + + StorageTableMetric(String keyspace, String table, 
CountingFunction func) + { + super(keyspace, table, func, "mebibytes", LongType.instance, ""); + } + + StorageTableMetric(String keyspace, String table, GaugeFunction func) + { + super(keyspace, table, func, "mebibytes", LongType.instance, ""); + } + + /** + * Convert bytes to mebibytes, always round up to nearest MiB + */ + public void add(SimpleDataSet result, String column, long value) + { + result.column(column, (long) Math.ceil(value * BYTES_TO_MIB)); + } + } + + /** + * A table that describes a Latency metric, specifically a Timer + */ + private static class HistogramTableMetric extends TableMetricTable + { + HistogramTableMetric(String keyspace, String table, Function func) + { + this(keyspace, table, func, ""); + } + + HistogramTableMetric(String keyspace, String table, Function func, String suffix) + { + super(keyspace, table, func, "count", LongType.instance, suffix); + } + + /** + * When displaying in cqlsh if we allow doubles to be too precise we get scientific notation which is hard to + * read so round off at 0.000. + */ + public void add(SimpleDataSet result, String column, double value) + { + result.column(column, Precision.round(value, 3, BigDecimal.ROUND_HALF_UP)); + } + } + + /** + * A table that describes a Latency metric, specifically a Timer + */ + private static class LatencyTableMetric extends HistogramTableMetric + { + LatencyTableMetric(String keyspace, String table, Function func) + { + super(keyspace, table, func, "_ms"); + } + + /** + * For the metrics that are time based, convert to to milliseconds + */ + public void add(SimpleDataSet result, String column, double value) + { + if (column.endsWith(suffix)) + value *= NS_TO_MS; + + super.add(result, column, value); + } + } + + /** + * Abstraction over the Metrics Gauge, Counter, and Timer that will turn it into a (keyspace_name, table_name) + * table. 
+ */ + private static class TableMetricTable extends AbstractVirtualTable + { + final Function func; + final String columnName; + final String suffix; + + TableMetricTable(String keyspace, String table, Function func, + String colName, AbstractType colType, String suffix) + { + super(buildMetadata(keyspace, table, func, colName, colType, suffix)); + this.func = func; + this.columnName = colName; + this.suffix = suffix; + } + + public void add(SimpleDataSet result, String column, double value) + { + result.column(column, value); + } + + public void add(SimpleDataSet result, String column, long value) + { + result.column(column, value); + } + + public DataSet data() + { + SimpleDataSet result = new SimpleDataSet(metadata()); + + // Iterate over all tables and get metric by function + for (ColumnFamilyStore cfs : ColumnFamilyStore.all()) + { + Metric metric = func.apply(cfs.metric); + + // set new partition for this table + result.row(cfs.keyspace.getName(), cfs.name); + + // extract information by metric type and put it in row based on implementation of `add` + if (metric instanceof Counting) + { + add(result, columnName, ((Counting) metric).getCount()); + if (metric instanceof Sampling) + { + Sampling histo = (Sampling) metric; + Snapshot snapshot = histo.getSnapshot(); + // EstimatedHistogram keeping them in ns is hard to parse as a human so convert to ms + add(result, P50 + suffix, snapshot.getMedian()); + add(result, P99 + suffix, snapshot.get99thPercentile()); + add(result, MAX + suffix, (double) snapshot.getMax()); + } + if (metric instanceof Metered) + { + Metered timer = (Metered) metric; + add(result, RATE, timer.getFiveMinuteRate()); + } + } + else if (metric instanceof Gauge) + { + add(result, columnName, (long) ((Gauge) metric).getValue()); + } + } + return result; + } + } + + /** + * Identify the type of Metric it is (gauge, counter etc) abd create the TableMetadata. 
The column name + * and type for a counter/gauge is formatted differently based on the units (bytes/time) so allowed to + * be set. + */ + private static TableMetadata buildMetadata(String keyspace, String table, Function func, + String colName, AbstractType colType, String suffix) + { + TableMetadata.Builder metadata = TableMetadata.builder(keyspace, table) + .kind(TableMetadata.Kind.VIRTUAL) + .addPartitionKeyColumn(KEYSPACE_NAME, UTF8Type.instance) + .addPartitionKeyColumn(TABLE_NAME, UTF8Type.instance) + .partitioner(PARTITIONER); + + // get a table from system keyspace and get metric from it for determining type of metric + Keyspace system = Keyspace.system().iterator().next(); + Metric test = func.apply(system.getColumnFamilyStores().iterator().next().metric); + + if (test instanceof Counting) + { + metadata.addRegularColumn(colName, colType); + // if it has a Histogram include some information about distribution + if (test instanceof Sampling) + { + metadata.addRegularColumn(P50 + suffix, DoubleType.instance) + .addRegularColumn(P99 + suffix, DoubleType.instance) + .addRegularColumn(MAX + suffix, DoubleType.instance); + } + if (test instanceof Metered) + { + metadata.addRegularColumn(RATE, DoubleType.instance); + } + } + else if (test instanceof Gauge) + { + metadata.addRegularColumn(colName, colType); + } + return metadata.build(); + } +} diff --git a/src/java/org/apache/cassandra/db/virtual/VirtualMutation.java b/src/java/org/apache/cassandra/db/virtual/VirtualMutation.java index dc32c8cca9ce..6db0acdf2613 100644 --- a/src/java/org/apache/cassandra/db/virtual/VirtualMutation.java +++ b/src/java/org/apache/cassandra/db/virtual/VirtualMutation.java @@ -18,6 +18,7 @@ package org.apache.cassandra.db.virtual; import java.util.Collection; +import java.util.concurrent.TimeUnit; import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableMap; @@ -76,9 +77,9 @@ public DecoratedKey key() } @Override - public long getTimeout() + public long 
getTimeout(TimeUnit unit) { - return DatabaseDescriptor.getWriteRpcTimeout(); + return DatabaseDescriptor.getWriteRpcTimeout(unit); } @Override diff --git a/src/java/org/apache/cassandra/dht/BootStrapper.java b/src/java/org/apache/cassandra/dht/BootStrapper.java index cef605eb87b7..94bf283e435b 100644 --- a/src/java/org/apache/cassandra/dht/BootStrapper.java +++ b/src/java/org/apache/cassandra/dht/BootStrapper.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.dht; -import java.io.IOException; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; @@ -25,19 +24,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.db.Keyspace; -import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.dht.tokenallocator.TokenAllocation; import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.gms.FailureDetector; import org.apache.cassandra.gms.Gossiper; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.AbstractReplicationStrategy; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.locator.NetworkTopologyStrategy; import org.apache.cassandra.locator.TokenMetadata; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.streaming.*; @@ -159,6 +155,7 @@ public void onFailure(Throwable throwable) public static Collection getBootstrapTokens(final TokenMetadata metadata, InetAddressAndPort address, int schemaWaitDelay) throws ConfigurationException { String allocationKeyspace = DatabaseDescriptor.getAllocateTokensForKeyspace(); + Integer allocationLocalRf = DatabaseDescriptor.getAllocateTokensForLocalRf(); Collection initialTokens = DatabaseDescriptor.getInitialTokens(); if 
(initialTokens.size() > 0 && allocationKeyspace != null) logger.warn("manually specified tokens override automatic allocation"); @@ -178,6 +175,9 @@ public static Collection getBootstrapTokens(final TokenMetadata metadata, if (allocationKeyspace != null) return allocateTokens(metadata, address, allocationKeyspace, numTokens, schemaWaitDelay); + if (allocationLocalRf != null) + return allocateTokens(metadata, address, allocationLocalRf, numTokens, schemaWaitDelay); + if (numTokens == 1) logger.warn("Picking random token for a single vnode. You should probably add more vnodes and/or use the automatic token allocation mechanism."); @@ -221,6 +221,22 @@ static Collection allocateTokens(final TokenMetadata metadata, return tokens; } + + static Collection allocateTokens(final TokenMetadata metadata, + InetAddressAndPort address, + int rf, + int numTokens, + int schemaWaitDelay) + { + StorageService.instance.waitForSchema(schemaWaitDelay); + if (!FBUtilities.getBroadcastAddressAndPort().equals(InetAddressAndPort.getLoopbackAddress())) + Gossiper.waitToSettle(); + + Collection tokens = TokenAllocation.allocateTokens(metadata, rf, address, numTokens); + BootstrapDiagnostics.tokensAllocated(address, metadata, rf, numTokens, tokens); + return tokens; + } + public static Collection getRandomTokens(TokenMetadata metadata, int numTokens) { Set tokens = new HashSet<>(numTokens); @@ -234,24 +250,4 @@ public static Collection getRandomTokens(TokenMetadata metadata, int numT logger.info("Generated random tokens. 
tokens are {}", tokens); return tokens; } - - public static class StringSerializer implements IVersionedSerializer - { - public static final StringSerializer instance = new StringSerializer(); - - public void serialize(String s, DataOutputPlus out, int version) throws IOException - { - out.writeUTF(s); - } - - public String deserialize(DataInputPlus in, int version) throws IOException - { - return in.readUTF(); - } - - public long serializedSize(String s, int version) - { - return TypeSizes.sizeof(s); - } - } } diff --git a/src/java/org/apache/cassandra/dht/BootstrapDiagnostics.java b/src/java/org/apache/cassandra/dht/BootstrapDiagnostics.java index 56955326de3f..5c2b46a03087 100644 --- a/src/java/org/apache/cassandra/dht/BootstrapDiagnostics.java +++ b/src/java/org/apache/cassandra/dht/BootstrapDiagnostics.java @@ -45,6 +45,7 @@ static void useSpecifiedTokens(InetAddressAndPort address, String allocationKeys address, null, allocationKeyspace, + null, numTokens, ImmutableList.copyOf(initialTokens))); } @@ -56,6 +57,7 @@ static void useRandomTokens(InetAddressAndPort address, TokenMetadata metadata, address, metadata.cloneOnlyTokenMap(), null, + null, numTokens, ImmutableList.copyOf(tokens))); } @@ -68,6 +70,20 @@ static void tokensAllocated(InetAddressAndPort address, TokenMetadata metadata, address, metadata.cloneOnlyTokenMap(), allocationKeyspace, + null, + numTokens, + ImmutableList.copyOf(tokens))); + } + + static void tokensAllocated(InetAddressAndPort address, TokenMetadata metadata, + int rf, int numTokens, Collection tokens) + { + if (isEnabled(BootstrapEventType.TOKENS_ALLOCATED)) + service.publish(new BootstrapEvent(BootstrapEventType.TOKENS_ALLOCATED, + address, + metadata.cloneOnlyTokenMap(), + null, + rf, numTokens, ImmutableList.copyOf(tokens))); } diff --git a/src/java/org/apache/cassandra/dht/BootstrapEvent.java b/src/java/org/apache/cassandra/dht/BootstrapEvent.java index 5bad09a19fa1..4936c2942ad9 100644 --- 
a/src/java/org/apache/cassandra/dht/BootstrapEvent.java +++ b/src/java/org/apache/cassandra/dht/BootstrapEvent.java @@ -42,16 +42,19 @@ final class BootstrapEvent extends DiagnosticEvent private final InetAddressAndPort address; @Nullable private final String allocationKeyspace; + @Nullable + private final Integer rf; private final Integer numTokens; private final Collection tokens; BootstrapEvent(BootstrapEventType type, InetAddressAndPort address, @Nullable TokenMetadata tokenMetadata, - @Nullable String allocationKeyspace, int numTokens, ImmutableCollection tokens) + @Nullable String allocationKeyspace, @Nullable Integer rf, int numTokens, ImmutableCollection tokens) { this.type = type; this.address = address; this.tokenMetadata = tokenMetadata; this.allocationKeyspace = allocationKeyspace; + this.rf = rf; this.numTokens = numTokens; this.tokens = tokens; } @@ -75,6 +78,7 @@ public Map toMap() HashMap ret = new HashMap<>(); ret.put("tokenMetadata", String.valueOf(tokenMetadata)); ret.put("allocationKeyspace", allocationKeyspace); + ret.put("rf", rf); ret.put("numTokens", numTokens); ret.put("tokens", String.valueOf(tokens)); return ret; diff --git a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java index 08088f730e2e..a6314dcccc8e 100644 --- a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java +++ b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java @@ -234,6 +234,12 @@ public Token fromByteArray(ByteBuffer bytes) return new BytesToken(bytes); } + @Override + public int byteSize(Token token) + { + return ((BytesToken) token).token.length; + } + public String toString(Token token) { BytesToken bytesToken = (BytesToken) token; diff --git a/src/java/org/apache/cassandra/dht/IPartitioner.java b/src/java/org/apache/cassandra/dht/IPartitioner.java index 5475f3c9c473..ef8ced25b146 100644 --- a/src/java/org/apache/cassandra/dht/IPartitioner.java +++ 
b/src/java/org/apache/cassandra/dht/IPartitioner.java @@ -18,6 +18,7 @@ package org.apache.cassandra.dht; import java.nio.ByteBuffer; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -25,9 +26,29 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.service.StorageService; public interface IPartitioner { + static IPartitioner global() + { + return StorageService.instance.getTokenMetadata().partitioner; + } + + static void validate(Collection> allBounds) + { + for (AbstractBounds bounds : allBounds) + validate(bounds); + } + + static void validate(AbstractBounds bounds) + { + if (global() != bounds.left.getPartitioner()) + throw new AssertionError(String.format("Partitioner in bounds serialization. Expected %s, was %s.", + global().getClass().getName(), + bounds.left.getPartitioner().getClass().getName())); + } + /** * Transform key to object representation of the on-disk format. 
* @@ -114,4 +135,9 @@ default Optional splitter() { return Optional.empty(); } + + default public int getMaxTokenSize() + { + return Integer.MIN_VALUE; + } } diff --git a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java index 0f922e340388..52d0efbb5837 100644 --- a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java +++ b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java @@ -17,6 +17,7 @@ */ package org.apache.cassandra.dht; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -25,10 +26,12 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PreHashedDecoratedKey; +import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.PartitionerDefinedOrder; import org.apache.cassandra.db.marshal.LongType; import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.MurmurHash; import org.apache.cassandra.utils.ObjectSizes; @@ -42,6 +45,7 @@ public class Murmur3Partitioner implements IPartitioner { public static final LongToken MINIMUM = new LongToken(Long.MIN_VALUE); public static final long MAXIMUM = Long.MAX_VALUE; + private static final int MAXIMUM_TOKEN_SIZE = TypeSizes.sizeof(MAXIMUM); private static final int HEAP_SIZE = (int) ObjectSizes.measureDeep(MINIMUM); @@ -224,6 +228,11 @@ private LongToken getToken(ByteBuffer key, long[] hash) return new LongToken(normalize(hash[0])); } + public int getMaxTokenSize() + { + return MAXIMUM_TOKEN_SIZE; + } + private long[] getHash(ByteBuffer key) { long[] hash = new long[2]; @@ -300,11 +309,35 @@ public ByteBuffer toByteArray(Token token) return ByteBufferUtil.bytes(longToken.token); } + @Override + public void serialize(Token token, DataOutputPlus 
out) throws IOException + { + out.writeLong(((LongToken) token).token); + } + + @Override + public void serialize(Token token, ByteBuffer out) + { + out.putLong(((LongToken) token).token); + } + + @Override + public int byteSize(Token token) + { + return 8; + } + public Token fromByteArray(ByteBuffer bytes) { return new LongToken(ByteBufferUtil.toLong(bytes)); } + @Override + public Token fromByteBuffer(ByteBuffer bytes, int position, int length) + { + return new LongToken(bytes.getLong(position)); + } + public String toString(Token token) { return token.toString(); diff --git a/src/java/org/apache/cassandra/dht/RandomPartitioner.java b/src/java/org/apache/cassandra/dht/RandomPartitioner.java index 4e63475bbe47..0457a893f0c5 100644 --- a/src/java/org/apache/cassandra/dht/RandomPartitioner.java +++ b/src/java/org/apache/cassandra/dht/RandomPartitioner.java @@ -17,6 +17,7 @@ */ package org.apache.cassandra.dht; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -31,6 +32,7 @@ import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.IntegerType; import org.apache.cassandra.db.marshal.PartitionerDefinedOrder; +import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.GuidGenerator; @@ -46,6 +48,7 @@ public class RandomPartitioner implements IPartitioner public static final BigInteger ZERO = new BigInteger("0"); public static final BigIntegerToken MINIMUM = new BigIntegerToken("-1"); public static final BigInteger MAXIMUM = new BigInteger("2").pow(127); + public static final int MAXIMUM_TOKEN_SIZE = MAXIMUM.bitLength() / 8 + 1; /** * Maintain a separate threadlocal message digest, exclusively for token hashing. 
This is necessary because @@ -162,11 +165,35 @@ public ByteBuffer toByteArray(Token token) return ByteBuffer.wrap(bigIntegerToken.token.toByteArray()); } + @Override + public void serialize(Token token, DataOutputPlus out) throws IOException + { + out.write(((BigIntegerToken) token).token.toByteArray()); + } + + @Override + public void serialize(Token token, ByteBuffer out) + { + out.put(((BigIntegerToken) token).token.toByteArray()); + } + + @Override + public int byteSize(Token token) + { + return ((BigIntegerToken) token).token.bitLength() / 8 + 1; + } + public Token fromByteArray(ByteBuffer bytes) { return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes))); } + @Override + public Token fromByteBuffer(ByteBuffer bytes, int position, int length) + { + return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes, position, length))); + } + public String toString(Token token) { BigIntegerToken bigIntegerToken = (BigIntegerToken) token; @@ -252,6 +279,11 @@ public BigIntegerToken getToken(ByteBuffer key) return new BigIntegerToken(hashToBigInteger(key)); } + public int getMaxTokenSize() + { + return MAXIMUM_TOKEN_SIZE; + } + public Map describeOwnership(List sortedTokens) { Map ownerships = new HashMap(); diff --git a/src/java/org/apache/cassandra/dht/Token.java b/src/java/org/apache/cassandra/dht/Token.java index 20b45ef0f237..ccb66fd4bb3e 100644 --- a/src/java/org/apache/cassandra/dht/Token.java +++ b/src/java/org/apache/cassandra/dht/Token.java @@ -26,7 +26,6 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.utils.ByteBufferUtil; public abstract class Token implements RingPosition, Serializable { @@ -40,8 +39,30 @@ public static abstract class TokenFactory public abstract Token fromByteArray(ByteBuffer bytes); public abstract String toString(Token token); // serialize as string, not necessarily 
human-readable public abstract Token fromString(String string); // deserialize - public abstract void validate(String token) throws ConfigurationException; + + public void serialize(Token token, DataOutputPlus out) throws IOException + { + out.write(toByteArray(token)); + } + + public void serialize(Token token, ByteBuffer out) throws IOException + { + out.put(toByteArray(token)); + } + + public Token fromByteBuffer(ByteBuffer bytes, int position, int length) + { + bytes = bytes.duplicate(); + bytes.position(position) + .limit(position + length); + return fromByteArray(bytes); + } + + public int byteSize(Token token) + { + return toByteArray(token).remaining(); + } } public static class TokenSerializer implements IPartitionerDependentSerializer @@ -49,23 +70,28 @@ public static class TokenSerializer implements IPartitionerDependentSerializer allocateTokens(final TokenMetadata tokenMetadata return tokens; } + public static Collection allocateTokens(final TokenMetadata tokenMetadata, + final int replicas, + final InetAddressAndPort endpoint, + int numTokens) + { + TokenMetadata tokenMetadataCopy = tokenMetadata.cloneOnlyTokenMap(); + StrategyAdapter strategy = getStrategy(tokenMetadataCopy, replicas, endpoint); + Collection tokens = create(tokenMetadata, strategy).addUnit(endpoint, numTokens); + tokens = adjustForCrossDatacenterClashes(tokenMetadata, strategy, tokens); + logger.warn("Selected tokens {}", tokens); + // SummaryStatistics is not implemented for `allocate_tokens_for_local_replication_factor` + return tokens; + } + private static Collection adjustForCrossDatacenterClashes(final TokenMetadata tokenMetadata, StrategyAdapter strategy, Collection tokens) { @@ -197,7 +212,17 @@ static StrategyAdapter getStrategy(final TokenMetadata tokenMetadata, final Netw { final String dc = snitch.getDatacenter(endpoint); final int replicas = rs.getReplicationFactor(dc).allReplicas; + return getStrategy(tokenMetadata, replicas, snitch, endpoint); + } + + static 
StrategyAdapter getStrategy(final TokenMetadata tokenMetadata, final int replicas, final InetAddressAndPort endpoint) + { + return getStrategy(tokenMetadata, replicas, DatabaseDescriptor.getEndpointSnitch(), endpoint); + } + static StrategyAdapter getStrategy(final TokenMetadata tokenMetadata, final int replicas, final IEndpointSnitch snitch, final InetAddressAndPort endpoint) + { + final String dc = snitch.getDatacenter(endpoint); if (replicas == 0 || replicas == 1) { // No replication, each node is treated as separate. @@ -224,7 +249,11 @@ public boolean inAllocationRing(InetAddressAndPort other) } Topology topology = tokenMetadata.getTopology(); - int racks = topology.getDatacenterRacks().get(dc).asMap().size(); + + // if topology hasn't been setup yet for this endpoint+rack then treat it as a separate unit + int racks = topology.getDatacenterRacks().get(dc) != null && topology.getDatacenterRacks().get(dc).containsKey(snitch.getRack(endpoint)) + ? topology.getDatacenterRacks().get(dc).asMap().size() + : 1; if (racks >= replicas) { diff --git a/src/java/org/apache/cassandra/exceptions/AuthenticationException.java b/src/java/org/apache/cassandra/exceptions/AuthenticationException.java index ce6cb2c7602b..067f3ae2f43e 100644 --- a/src/java/org/apache/cassandra/exceptions/AuthenticationException.java +++ b/src/java/org/apache/cassandra/exceptions/AuthenticationException.java @@ -23,4 +23,9 @@ public AuthenticationException(String msg) { super(ExceptionCode.BAD_CREDENTIALS, msg); } + + public AuthenticationException(String msg, Throwable e) + { + super(ExceptionCode.BAD_CREDENTIALS, msg, e); + } } diff --git a/src/java/org/apache/cassandra/net/async/ExpiredException.java b/src/java/org/apache/cassandra/exceptions/IncompatibleSchemaException.java similarity index 71% rename from src/java/org/apache/cassandra/net/async/ExpiredException.java rename to src/java/org/apache/cassandra/exceptions/IncompatibleSchemaException.java index 191900c4b4f3..fe3a167b6f72 100644 --- 
a/src/java/org/apache/cassandra/net/async/ExpiredException.java +++ b/src/java/org/apache/cassandra/exceptions/IncompatibleSchemaException.java @@ -15,14 +15,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.cassandra.exceptions; -package org.apache.cassandra.net.async; +import java.io.IOException; -/** - * Thrown when a {@link QueuedMessage} has timed out (has sat in the netty outbound channel for too long). - */ -class ExpiredException extends Exception +public class IncompatibleSchemaException extends IOException { - @SuppressWarnings("ThrowableInstanceNeverThrown") - static final ExpiredException INSTANCE = new ExpiredException(); + public IncompatibleSchemaException(String msg) + { + super(msg); + } } diff --git a/src/java/org/apache/cassandra/exceptions/RequestFailureException.java b/src/java/org/apache/cassandra/exceptions/RequestFailureException.java index 2b57a7517421..e982b44f02e7 100644 --- a/src/java/org/apache/cassandra/exceptions/RequestFailureException.java +++ b/src/java/org/apache/cassandra/exceptions/RequestFailureException.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.locator.InetAddressAndPort; @@ -32,7 +33,10 @@ public class RequestFailureException extends RequestExecutionException protected RequestFailureException(ExceptionCode code, ConsistencyLevel consistency, int received, int blockFor, Map failureReasonByEndpoint) { - super(code, String.format("Operation failed - received %d responses and %d failures", received, failureReasonByEndpoint.size())); + super(code, String.format("Operation failed - received %d responses and %d failures: %s", + received, + failureReasonByEndpoint.size(), + buildFailureString(failureReasonByEndpoint))); this.consistency = consistency; this.received = received; this.blockFor = blockFor; @@ -45,4 
+49,11 @@ protected RequestFailureException(ExceptionCode code, ConsistencyLevel consisten // we encode this map for transport. this.failureReasonByEndpoint = new HashMap<>(failureReasonByEndpoint); } + + private static String buildFailureString(Map failures) + { + return failures.entrySet().stream() + .map(e -> String.format("%s from %s", e.getValue(), e.getKey())) + .collect(Collectors.joining(", ")); + } } diff --git a/src/java/org/apache/cassandra/exceptions/RequestFailureReason.java b/src/java/org/apache/cassandra/exceptions/RequestFailureReason.java index 96ab7b5d0cfc..1cdbdb544d28 100644 --- a/src/java/org/apache/cassandra/exceptions/RequestFailureReason.java +++ b/src/java/org/apache/cassandra/exceptions/RequestFailureReason.java @@ -15,37 +15,101 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.cassandra.exceptions; +import java.io.IOException; + +import com.google.common.primitives.Ints; + +import org.apache.cassandra.db.filter.TombstoneOverwhelmingException; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.utils.vint.VIntCoding; + +import static java.lang.Math.max; +import static org.apache.cassandra.net.MessagingService.VERSION_40; + public enum RequestFailureReason { - /** - * The reason for the failure was none of the below reasons or was not recorded by the data node. - */ - UNKNOWN (0x0000), + UNKNOWN (0), + READ_TOO_MANY_TOMBSTONES (1), + TIMEOUT (2), + INCOMPATIBLE_SCHEMA (3); - /** - * The data node read too many tombstones when attempting to execute a read query (see tombstone_failure_threshold). 
- */ - READ_TOO_MANY_TOMBSTONES (0x0001); + public static final Serializer serializer = new Serializer(); - /** The code to be serialized as an unsigned 16 bit integer */ public final int code; - public static final RequestFailureReason[] VALUES = values(); - RequestFailureReason(final int code) + RequestFailureReason(int code) { this.code = code; } - public static RequestFailureReason fromCode(final int code) + private static final RequestFailureReason[] codeToReasonMap; + + static { - for (RequestFailureReason reasonCode : VALUES) + RequestFailureReason[] reasons = values(); + + int max = -1; + for (RequestFailureReason r : reasons) + max = max(r.code, max); + + RequestFailureReason[] codeMap = new RequestFailureReason[max + 1]; + + for (RequestFailureReason reason : reasons) + { + if (codeMap[reason.code] != null) + throw new RuntimeException("Two RequestFailureReason-s that map to the same code: " + reason.code); + codeMap[reason.code] = reason; + } + + codeToReasonMap = codeMap; + } + + public static RequestFailureReason fromCode(int code) + { + if (code < 0) + throw new IllegalArgumentException("RequestFailureReason code must be non-negative (got " + code + ')'); + + // be forgiving and return UNKNOWN if we aren't aware of the code - for forward compatibility + return code < codeToReasonMap.length ? 
codeToReasonMap[code] : UNKNOWN; + } + + public static RequestFailureReason forException(Throwable t) + { + if (t instanceof TombstoneOverwhelmingException) + return READ_TOO_MANY_TOMBSTONES; + + if (t instanceof IncompatibleSchemaException) + return INCOMPATIBLE_SCHEMA; + + return UNKNOWN; + } + + public static final class Serializer implements IVersionedSerializer + { + private Serializer() + { + } + + public void serialize(RequestFailureReason reason, DataOutputPlus out, int version) throws IOException + { + if (version < VERSION_40) + out.writeShort(reason.code); + else + out.writeUnsignedVInt(reason.code); + } + + public RequestFailureReason deserialize(DataInputPlus in, int version) throws IOException + { + return fromCode(version < VERSION_40 ? in.readUnsignedShort() : Ints.checkedCast(in.readUnsignedVInt())); + } + + public long serializedSize(RequestFailureReason reason, int version) { - if (reasonCode.code == code) - return reasonCode; + return version < VERSION_40 ? 2 : VIntCoding.computeVIntSize(reason.code); } - throw new IllegalArgumentException("Unknown request failure reason error code: " + code); } } diff --git a/src/java/org/apache/cassandra/exceptions/UnauthorizedException.java b/src/java/org/apache/cassandra/exceptions/UnauthorizedException.java index 12a3f8af6767..008d793545ea 100644 --- a/src/java/org/apache/cassandra/exceptions/UnauthorizedException.java +++ b/src/java/org/apache/cassandra/exceptions/UnauthorizedException.java @@ -23,4 +23,9 @@ public UnauthorizedException(String msg) { super(ExceptionCode.UNAUTHORIZED, msg); } + + public UnauthorizedException(String msg, Throwable e) + { + super(ExceptionCode.UNAUTHORIZED, msg, e); + } } diff --git a/src/java/org/apache/cassandra/exceptions/UnknownColumnException.java b/src/java/org/apache/cassandra/exceptions/UnknownColumnException.java new file mode 100644 index 000000000000..93a464e77e02 --- /dev/null +++ b/src/java/org/apache/cassandra/exceptions/UnknownColumnException.java @@ -0,0 +1,26 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.exceptions; + +public final class UnknownColumnException extends IncompatibleSchemaException +{ + public UnknownColumnException(String msg) + { + super(msg); + } +} diff --git a/src/java/org/apache/cassandra/exceptions/UnknownTableException.java b/src/java/org/apache/cassandra/exceptions/UnknownTableException.java index 2cd7aab21121..3e9c77537061 100644 --- a/src/java/org/apache/cassandra/exceptions/UnknownTableException.java +++ b/src/java/org/apache/cassandra/exceptions/UnknownTableException.java @@ -17,11 +17,9 @@ */ package org.apache.cassandra.exceptions; -import java.io.IOException; - import org.apache.cassandra.schema.TableId; -public class UnknownTableException extends IOException +public class UnknownTableException extends IncompatibleSchemaException { public final TableId id; diff --git a/src/java/org/apache/cassandra/gms/EchoMessage.java b/src/java/org/apache/cassandra/gms/EchoMessage.java deleted file mode 100644 index 2fee889f6b72..000000000000 --- a/src/java/org/apache/cassandra/gms/EchoMessage.java +++ /dev/null @@ -1,56 +0,0 @@ -package org.apache.cassandra.gms; -/* - * - * Licensed to the Apache Software Foundation 
(ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - - -import java.io.IOException; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; - -public final class EchoMessage -{ - public static final EchoMessage instance = new EchoMessage(); - - public static final IVersionedSerializer serializer = new EchoMessageSerializer(); - - private EchoMessage() - { - } - - public static class EchoMessageSerializer implements IVersionedSerializer - { - public void serialize(EchoMessage t, DataOutputPlus out, int version) throws IOException - { - } - - public EchoMessage deserialize(DataInputPlus in, int version) throws IOException - { - return EchoMessage.instance; - } - - public long serializedSize(EchoMessage t, int version) - { - return 0; - } - } -} diff --git a/src/java/org/apache/cassandra/gms/FailureDetector.java b/src/java/org/apache/cassandra/gms/FailureDetector.java index 4a16f2a9e8c2..d3a5f340080b 100644 --- a/src/java/org/apache/cassandra/gms/FailureDetector.java +++ b/src/java/org/apache/cassandra/gms/FailureDetector.java @@ -37,10 +37,11 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.io.FSWriteError; import 
org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.utils.Clock; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.MBeanWrapper; +import static org.apache.cassandra.utils.MonotonicClock.preciseTime; + /** * This FailureDetector is an implementation of the paper titled * "The Phi Accrual Failure Detector" by Hayashibara. @@ -55,7 +56,7 @@ public class FailureDetector implements IFailureDetector, FailureDetectorMBean private static final int DEBUG_PERCENTAGE = 80; // if the phi is larger than this percentage of the max, log a debug message private static final long DEFAULT_MAX_PAUSE = 5000L * 1000000L; // 5 seconds private static final long MAX_LOCAL_PAUSE_IN_NANOS = getMaxLocalPause(); - private long lastInterpret = Clock.instance.nanoTime(); + private long lastInterpret = preciseTime.now(); private long lastPause = 0L; private static long getMaxLocalPause() @@ -283,7 +284,7 @@ public boolean isAlive(InetAddressAndPort ep) public void report(InetAddressAndPort ep) { - long now = Clock.instance.nanoTime(); + long now = preciseTime.now(); ArrivalWindow heartbeatWindow = arrivalSamples.get(ep); if (heartbeatWindow == null) { @@ -310,7 +311,7 @@ public void interpret(InetAddressAndPort ep) { return; } - long now = Clock.instance.nanoTime(); + long now = preciseTime.now(); long diff = now - lastInterpret; lastInterpret = now; if (diff > MAX_LOCAL_PAUSE_IN_NANOS) @@ -319,7 +320,7 @@ public void interpret(InetAddressAndPort ep) lastPause = now; return; } - if (Clock.instance.nanoTime() - lastPause < MAX_LOCAL_PAUSE_IN_NANOS) + if (preciseTime.now() - lastPause < MAX_LOCAL_PAUSE_IN_NANOS) { logger.debug("Still not marking nodes down due to local pause"); return; diff --git a/src/java/org/apache/cassandra/gms/GossipDigest.java b/src/java/org/apache/cassandra/gms/GossipDigest.java index c7e60c4c3443..53f6c5c52c59 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigest.java +++ 
b/src/java/org/apache/cassandra/gms/GossipDigest.java @@ -24,7 +24,8 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; /** * Contains information about a specified list of Endpoints and the largest version @@ -83,14 +84,14 @@ class GossipDigestSerializer implements IVersionedSerializer { public void serialize(GossipDigest gDigest, DataOutputPlus out, int version) throws IOException { - CompactEndpointSerializationHelper.instance.serialize(gDigest.endpoint, out, version); + inetAddressAndPortSerializer.serialize(gDigest.endpoint, out, version); out.writeInt(gDigest.generation); out.writeInt(gDigest.maxVersion); } public GossipDigest deserialize(DataInputPlus in, int version) throws IOException { - InetAddressAndPort endpoint = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort endpoint = inetAddressAndPortSerializer.deserialize(in, version); int generation = in.readInt(); int maxVersion = in.readInt(); return new GossipDigest(endpoint, generation, maxVersion); @@ -98,7 +99,7 @@ public GossipDigest deserialize(DataInputPlus in, int version) throws IOExceptio public long serializedSize(GossipDigest gDigest, int version) { - long size = CompactEndpointSerializationHelper.instance.serializedSize(gDigest.endpoint, version); + long size = inetAddressAndPortSerializer.serializedSize(gDigest.endpoint, version); size += TypeSizes.sizeof(gDigest.generation); size += TypeSizes.sizeof(gDigest.maxVersion); return size; diff --git a/src/java/org/apache/cassandra/gms/GossipDigestAck.java b/src/java/org/apache/cassandra/gms/GossipDigestAck.java index a7d5b92b303f..26494eaba9d4 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigestAck.java +++ 
b/src/java/org/apache/cassandra/gms/GossipDigestAck.java @@ -27,7 +27,8 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; /** * This ack gets sent out as a result of the receipt of a GossipDigestSynMessage by an @@ -66,7 +67,7 @@ public void serialize(GossipDigestAck gDigestAckMessage, DataOutputPlus out, int for (Map.Entry entry : gDigestAckMessage.epStateMap.entrySet()) { InetAddressAndPort ep = entry.getKey(); - CompactEndpointSerializationHelper.instance.serialize(ep, out, version); + inetAddressAndPortSerializer.serialize(ep, out, version); EndpointState.serializer.serialize(entry.getValue(), out, version); } } @@ -79,7 +80,7 @@ public GossipDigestAck deserialize(DataInputPlus in, int version) throws IOExcep for (int i = 0; i < size; ++i) { - InetAddressAndPort ep = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort ep = inetAddressAndPortSerializer.deserialize(in, version); EndpointState epState = EndpointState.serializer.deserialize(in, version); epStateMap.put(ep, epState); } @@ -91,7 +92,7 @@ public long serializedSize(GossipDigestAck ack, int version) int size = GossipDigestSerializationHelper.serializedSize(ack.gDigestList, version); size += TypeSizes.sizeof(ack.epStateMap.size()); for (Map.Entry entry : ack.epStateMap.entrySet()) - size += CompactEndpointSerializationHelper.instance.serializedSize(entry.getKey(), version) + size += inetAddressAndPortSerializer.serializedSize(entry.getKey(), version) + EndpointState.serializer.serializedSize(entry.getValue(), version); return size; } diff --git a/src/java/org/apache/cassandra/gms/GossipDigestAck2.java b/src/java/org/apache/cassandra/gms/GossipDigestAck2.java index 
a6d1d2b196d3..0e4062bb0f44 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigestAck2.java +++ b/src/java/org/apache/cassandra/gms/GossipDigestAck2.java @@ -26,7 +26,8 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; /** * This ack gets sent out as a result of the receipt of a GossipDigestAckMessage. This the @@ -57,7 +58,7 @@ public void serialize(GossipDigestAck2 ack2, DataOutputPlus out, int version) th for (Map.Entry entry : ack2.epStateMap.entrySet()) { InetAddressAndPort ep = entry.getKey(); - CompactEndpointSerializationHelper.instance.serialize(ep, out, version); + inetAddressAndPortSerializer.serialize(ep, out, version); EndpointState.serializer.serialize(entry.getValue(), out, version); } } @@ -69,7 +70,7 @@ public GossipDigestAck2 deserialize(DataInputPlus in, int version) throws IOExce for (int i = 0; i < size; ++i) { - InetAddressAndPort ep = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort ep = inetAddressAndPortSerializer.deserialize(in, version); EndpointState epState = EndpointState.serializer.deserialize(in, version); epStateMap.put(ep, epState); } @@ -80,7 +81,7 @@ public long serializedSize(GossipDigestAck2 ack2, int version) { long size = TypeSizes.sizeof(ack2.epStateMap.size()); for (Map.Entry entry : ack2.epStateMap.entrySet()) - size += CompactEndpointSerializationHelper.instance.serializedSize(entry.getKey(), version) + size += inetAddressAndPortSerializer.serializedSize(entry.getKey(), version) + EndpointState.serializer.serializedSize(entry.getValue(), version); return size; } diff --git a/src/java/org/apache/cassandra/gms/GossipDigestAck2VerbHandler.java 
b/src/java/org/apache/cassandra/gms/GossipDigestAck2VerbHandler.java index fd5d4876b45f..58c1589eca91 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigestAck2VerbHandler.java +++ b/src/java/org/apache/cassandra/gms/GossipDigestAck2VerbHandler.java @@ -23,18 +23,19 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; -public class GossipDigestAck2VerbHandler implements IVerbHandler +public class GossipDigestAck2VerbHandler extends GossipVerbHandler { + public static final GossipDigestAck2VerbHandler instance = new GossipDigestAck2VerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(GossipDigestAck2VerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { if (logger.isTraceEnabled()) { - InetAddressAndPort from = message.from; + InetAddressAndPort from = message.from(); logger.trace("Received a GossipDigestAck2Message from {}", from); } if (!Gossiper.instance.isEnabled()) @@ -47,5 +48,7 @@ public void doVerb(MessageIn message, int id) /* Notify the Failure Detector */ Gossiper.instance.notifyFailureDetector(remoteEpStateMap); Gossiper.instance.applyStateLocally(remoteEpStateMap); + + super.doVerb(message); } } diff --git a/src/java/org/apache/cassandra/gms/GossipDigestAckVerbHandler.java b/src/java/org/apache/cassandra/gms/GossipDigestAckVerbHandler.java index 2a12b7c3995a..1e8604b66606 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigestAckVerbHandler.java +++ b/src/java/org/apache/cassandra/gms/GossipDigestAckVerbHandler.java @@ -25,18 +25,20 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; 
import org.apache.cassandra.net.MessagingService; -public class GossipDigestAckVerbHandler implements IVerbHandler +import static org.apache.cassandra.net.Verb.GOSSIP_DIGEST_ACK2; + +public class GossipDigestAckVerbHandler extends GossipVerbHandler { + public static final GossipDigestAckVerbHandler instance = new GossipDigestAckVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(GossipDigestAckVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - InetAddressAndPort from = message.from; + InetAddressAndPort from = message.from(); if (logger.isTraceEnabled()) logger.trace("Received a GossipDigestAckMessage from {}", from); if (!Gossiper.instance.isEnabled() && !Gossiper.instance.isInShadowRound()) @@ -88,11 +90,11 @@ public void doVerb(MessageIn message, int id) deltaEpStateMap.put(addr, localEpStatePtr); } - MessageOut gDigestAck2Message = new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_ACK2, - new GossipDigestAck2(deltaEpStateMap), - GossipDigestAck2.serializer); + Message gDigestAck2Message = Message.out(GOSSIP_DIGEST_ACK2, new GossipDigestAck2(deltaEpStateMap)); if (logger.isTraceEnabled()) logger.trace("Sending a GossipDigestAck2Message to {}", from); - MessagingService.instance().sendOneWay(gDigestAck2Message, from); + MessagingService.instance().send(gDigestAck2Message, from); + + super.doVerb(message); } } diff --git a/src/java/org/apache/cassandra/gms/GossipDigestSynVerbHandler.java b/src/java/org/apache/cassandra/gms/GossipDigestSynVerbHandler.java index b06c24dcdf54..520dbec3f606 100644 --- a/src/java/org/apache/cassandra/gms/GossipDigestSynVerbHandler.java +++ b/src/java/org/apache/cassandra/gms/GossipDigestSynVerbHandler.java @@ -24,18 +24,20 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import 
org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -public class GossipDigestSynVerbHandler implements IVerbHandler +import static org.apache.cassandra.net.Verb.*; + +public class GossipDigestSynVerbHandler extends GossipVerbHandler { + public static final GossipDigestSynVerbHandler instance = new GossipDigestSynVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(GossipDigestSynVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - InetAddressAndPort from = message.from; + InetAddressAndPort from = message.from(); if (logger.isTraceEnabled()) logger.trace("Received a GossipDigestSynMessage from {}", from); if (!Gossiper.instance.isEnabled() && !Gossiper.instance.isInShadowRound()) @@ -79,10 +81,8 @@ public void doVerb(MessageIn message, int id) logger.debug("Received a shadow round syn from {}. Gossip is disabled but " + "currently also in shadow round, responding with a minimal ack", from); MessagingService.instance() - .sendOneWay(new MessageOut<>(MessagingService.Verb.GOSSIP_DIGEST_ACK, - new GossipDigestAck(new ArrayList<>(), new HashMap<>()), - GossipDigestAck.serializer), - from); + .send(Message.out(GOSSIP_DIGEST_ACK, new GossipDigestAck(Collections.emptyList(), Collections.emptyMap())), + from); return; } @@ -101,11 +101,11 @@ public void doVerb(MessageIn message, int id) Map deltaEpStateMap = new HashMap(); Gossiper.instance.examineGossiper(gDigestList, deltaGossipDigestList, deltaEpStateMap); logger.trace("sending {} digests and {} deltas", deltaGossipDigestList.size(), deltaEpStateMap.size()); - MessageOut gDigestAckMessage = new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_ACK, - new GossipDigestAck(deltaGossipDigestList, deltaEpStateMap), - GossipDigestAck.serializer); + Message gDigestAckMessage = Message.out(GOSSIP_DIGEST_ACK, new GossipDigestAck(deltaGossipDigestList, deltaEpStateMap)); if 
(logger.isTraceEnabled()) logger.trace("Sending a GossipDigestAckMessage to {}", from); - MessagingService.instance().sendOneWay(gDigestAckMessage, from); + MessagingService.instance().send(gDigestAckMessage, from); + + super.doVerb(message); } } diff --git a/src/java/org/apache/cassandra/gms/GossipShutdownVerbHandler.java b/src/java/org/apache/cassandra/gms/GossipShutdownVerbHandler.java index 169110733c50..83c8568274f9 100644 --- a/src/java/org/apache/cassandra/gms/GossipShutdownVerbHandler.java +++ b/src/java/org/apache/cassandra/gms/GossipShutdownVerbHandler.java @@ -18,23 +18,25 @@ package org.apache.cassandra.gms; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class GossipShutdownVerbHandler implements IVerbHandler { + public static final GossipShutdownVerbHandler instance = new GossipShutdownVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(GossipShutdownVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { if (!Gossiper.instance.isEnabled()) { - logger.debug("Ignoring shutdown message from {} because gossip is disabled", message.from); + logger.debug("Ignoring shutdown message from {} because gossip is disabled", message.from()); return; } - Gossiper.instance.markAsShutdown(message.from); + Gossiper.instance.markAsShutdown(message.from()); } } \ No newline at end of file diff --git a/src/java/org/apache/cassandra/gms/GossipVerbHandler.java b/src/java/org/apache/cassandra/gms/GossipVerbHandler.java new file mode 100644 index 000000000000..02aeaf4467c8 --- /dev/null +++ b/src/java/org/apache/cassandra/gms/GossipVerbHandler.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.gms; + +import org.apache.cassandra.net.IVerbHandler; +import org.apache.cassandra.net.Message; + +public class GossipVerbHandler implements IVerbHandler +{ + public void doVerb(Message message) + { + Gossiper.instance.setLastProcessedMessageAt(message.creationTimeMillis()); + } +} diff --git a/src/java/org/apache/cassandra/gms/Gossiper.java b/src/java/org/apache/cassandra/gms/Gossiper.java index 8955bf950673..71f44622eb68 100644 --- a/src/java/org/apache/cassandra/gms/Gossiper.java +++ b/src/java/org/apache/cassandra/gms/Gossiper.java @@ -27,6 +27,7 @@ import java.util.stream.Collectors; import javax.annotation.Nullable; +import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; @@ -39,8 +40,11 @@ import com.google.common.util.concurrent.Uninterruptibles; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.NoPayload; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.utils.CassandraVersion; import io.netty.util.concurrent.FastThreadLocal; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.MBeanWrapper; import org.apache.cassandra.utils.NoSpamLogger; import 
org.apache.cassandra.utils.Pair; @@ -53,13 +57,18 @@ import org.apache.cassandra.concurrent.StageManager; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.dht.Token; -import org.apache.cassandra.net.IAsyncCallback; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.JVMStabilityInspector; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdown; + +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.Verb.ECHO_REQ; +import static org.apache.cassandra.net.Verb.GOSSIP_DIGEST_SYN; /** * This module is responsible for Gossiping information for the local endpoint. 
This abstraction @@ -236,9 +245,7 @@ public void run() GossipDigestSyn digestSynMessage = new GossipDigestSyn(DatabaseDescriptor.getClusterName(), DatabaseDescriptor.getPartitionerName(), gDigests); - MessageOut message = new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_SYN, - digestSynMessage, - GossipDigestSyn.serializer); + Message message = Message.out(GOSSIP_DIGEST_SYN, digestSynMessage); /* Gossip to some random live member */ boolean gossipedToSeed = doGossipToLiveMember(message); @@ -545,11 +552,11 @@ public void removeEndpoint(InetAddressAndPort endpoint) liveEndpoints.remove(endpoint); unreachableEndpoints.remove(endpoint); - MessagingService.instance().resetVersion(endpoint); + MessagingService.instance().versions.reset(endpoint); quarantineEndpoint(endpoint); - MessagingService.instance().destroyConnectionPool(endpoint); - if (logger.isDebugEnabled()) - logger.debug("removing endpoint {}", endpoint); + MessagingService.instance().closeOutbound(endpoint); + MessagingService.instance().removeInbound(endpoint); + logger.debug("removing endpoint {}", endpoint); GossiperDiagnostics.removedEndpoint(this, endpoint); } @@ -777,7 +784,7 @@ public int getCurrentGenerationNumber(InetAddressAndPort endpoint) * @param epSet a set of endpoint from which a random endpoint is chosen. * @return true if the chosen endpoint is also a seed. 
*/ - private boolean sendGossip(MessageOut message, Set epSet) + private boolean sendGossip(Message message, Set epSet) { List liveEndpoints = ImmutableList.copyOf(epSet); @@ -791,7 +798,7 @@ private boolean sendGossip(MessageOut message, Set message, Set message) + private boolean doGossipToLiveMember(Message message) { int size = liveEndpoints.size(); if (size == 0) @@ -808,7 +815,7 @@ private boolean doGossipToLiveMember(MessageOut message) } /* Sends a Gossip message to an unreachable member */ - private void maybeGossipToUnreachableMember(MessageOut message) + private void maybeGossipToUnreachableMember(Message message) { double liveEndpointCount = liveEndpoints.size(); double unreachableEndpointCount = unreachableEndpoints.size(); @@ -823,7 +830,7 @@ private void maybeGossipToUnreachableMember(MessageOut message) } /* Possibly gossip to a seed for facilitating partition healing */ - private void maybeGossipToSeed(MessageOut prod) + private void maybeGossipToSeed(Message prod) { int size = seeds.size(); if (size > 0) @@ -1145,23 +1152,15 @@ private void markAlive(final InetAddressAndPort addr, final EndpointState localS { localState.markDead(); - MessageOut echoMessage = new MessageOut(MessagingService.Verb.ECHO, EchoMessage.instance, EchoMessage.serializer); - logger.trace("Sending a EchoMessage to {}", addr); - IAsyncCallback echoHandler = new IAsyncCallback() + Message echoMessage = Message.out(ECHO_REQ, noPayload); + logger.trace("Sending ECHO_REQ to {}", addr); + RequestCallback echoHandler = msg -> { - public boolean isLatencyForSnitch() - { - return false; - } - - public void response(MessageIn msg) - { - // force processing of the echo response onto the gossip stage, as it comes in on the REQUEST_RESPONSE stage - runInGossipStageBlocking(() -> realMarkAlive(addr, localState)); - } + // force processing of the echo response onto the gossip stage, as it comes in on the REQUEST_RESPONSE stage + runInGossipStageBlocking(() -> realMarkAlive(addr, 
localState)); }; - MessagingService.instance().sendRR(echoMessage, addr, echoHandler); + MessagingService.instance().sendWithCallback(echoMessage, addr, echoHandler); GossiperDiagnostics.markedAlive(this, addr, localState); } @@ -1382,25 +1381,41 @@ private void applyNewStates(InetAddressAndPort addr, EndpointState localState, E Set> remoteStates = remoteState.states(); assert remoteState.getHeartBeatState().getGeneration() == localState.getHeartBeatState().getGeneration(); - localState.addApplicationStates(remoteStates); - - //Filter out pre-4.0 versions of data for more complete 4.0 versions - Set> filtered = remoteStates.stream().filter(entry -> { - switch (entry.getKey()) - { - case INTERNAL_IP: - return remoteState.getApplicationState(ApplicationState.INTERNAL_ADDRESS_AND_PORT) == null; - case STATUS: - return remoteState.getApplicationState(ApplicationState.STATUS_WITH_PORT) == null; - case RPC_ADDRESS: - return remoteState.getApplicationState(ApplicationState.NATIVE_ADDRESS_AND_PORT) == null; - default: - return true; - } + + + Set> updatedStates = remoteStates.stream().filter(entry -> { + // Filter out pre-4.0 versions of data for more complete 4.0 versions + switch (entry.getKey()) + { + case INTERNAL_IP: + if (remoteState.getApplicationState(ApplicationState.INTERNAL_ADDRESS_AND_PORT) != null) return false; + break; + case STATUS: + if (remoteState.getApplicationState(ApplicationState.STATUS_WITH_PORT) != null) return false; + break; + case RPC_ADDRESS: + if (remoteState.getApplicationState(ApplicationState.NATIVE_ADDRESS_AND_PORT) != null) return false; + break; + default: + break; + } + + // filter out the states that are already up to date (has the same or higher version) + VersionedValue local = localState.getApplicationState(entry.getKey()); + return (local == null || local.version < entry.getValue().version); }).collect(Collectors.toSet()); - for (Entry remoteEntry : filtered) - doOnChangeNotifications(addr, remoteEntry.getKey(), 
remoteEntry.getValue()); + if (logger.isTraceEnabled() && updatedStates.size() > 0) + { + for (Entry entry : updatedStates) + { + logger.trace("Updating {} state version to {} for {}", entry.getKey().toString(), entry.getValue().version, addr); + } + } + localState.addApplicationStates(updatedStates); + + for (Entry updatedEntry : updatedStates) + doOnChangeNotifications(addr, updatedEntry.getKey(), updatedEntry.getValue()); } // notify that a local application state is going to change (doesn't get triggered for remote changes) @@ -1447,7 +1462,7 @@ void examineGossiper(List gDigestList, List deltaGos if (gDigestList.size() == 0) { /* we've been sent a *completely* empty syn, which should normally never happen since an endpoint will at least send a syn with itself. - If this is happening then the node is attempting shadow gossip, and we should reply with everything we know. + If this is happening then the node is attempting shadow gossip, and we should respond with everything we know. */ logger.debug("Shadow request received, adding all states"); for (Map.Entry entry : endpointStateMap.entrySet()) @@ -1582,9 +1597,7 @@ public synchronized Map doShadowRound(Set message = new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_SYN, - digestSynMessage, - GossipDigestSyn.serializer); + Message message = Message.out(GOSSIP_DIGEST_SYN, digestSynMessage); inShadowRound = true; boolean includePeers = false; @@ -1598,14 +1611,14 @@ public synchronized Map doShadowRound(Set nodes) return true; } + + @VisibleForTesting + public void stopShutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException + { + stop(); + ExecutorUtils.shutdownAndWait(timeout, unit, executor); + } } diff --git a/src/java/org/apache/cassandra/hints/EncodedHintMessage.java b/src/java/org/apache/cassandra/hints/EncodedHintMessage.java deleted file mode 100644 index 50d130262248..000000000000 --- a/src/java/org/apache/cassandra/hints/EncodedHintMessage.java +++ /dev/null @@ -1,94 
+0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.hints; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.UUID; - -import org.apache.cassandra.db.TypeSizes; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.UUIDSerializer; - -/** - * A specialized version of {@link HintMessage} that takes an already encoded in a bytebuffer hint and sends it verbatim. - * - * An optimization for when dispatching a hint file of the current messaging version to a node of the same messaging version, - * which is the most common case. Saves on extra ByteBuffer allocations one redundant hint deserialization-serialization cycle. - * - * Never deserialized as an EncodedHintMessage - the receiving side will always deserialize the message as vanilla - * {@link HintMessage}. 
- */ -final class EncodedHintMessage -{ - private static final IVersionedSerializer serializer = new Serializer(); - - private final UUID hostId; - private final ByteBuffer hint; - private final int version; - - EncodedHintMessage(UUID hostId, ByteBuffer hint, int version) - { - this.hostId = hostId; - this.hint = hint; - this.version = version; - } - - MessageOut createMessageOut() - { - return new MessageOut<>(MessagingService.Verb.HINT, this, serializer); - } - - public long getHintCreationTime() - { - return Hint.serializer.getHintCreationTime(hint, version); - } - - private static class Serializer implements IVersionedSerializer - { - public long serializedSize(EncodedHintMessage message, int version) - { - if (version != message.version) - throw new IllegalArgumentException("serializedSize() called with non-matching version " + version); - - long size = UUIDSerializer.serializer.serializedSize(message.hostId, version); - size += TypeSizes.sizeofUnsignedVInt(message.hint.remaining()); - size += message.hint.remaining(); - return size; - } - - public void serialize(EncodedHintMessage message, DataOutputPlus out, int version) throws IOException - { - if (version != message.version) - throw new IllegalArgumentException("serialize() called with non-matching version " + version); - - UUIDSerializer.serializer.serialize(message.hostId, out, version); - out.writeUnsignedVInt(message.hint.remaining()); - out.write(message.hint); - } - - public EncodedHintMessage deserialize(DataInputPlus in, int version) throws IOException - { - throw new UnsupportedOperationException(); - } - } -} diff --git a/src/java/org/apache/cassandra/hints/HintMessage.java b/src/java/org/apache/cassandra/hints/HintMessage.java index 683b894cb6a7..333af842dc09 100644 --- a/src/java/org/apache/cassandra/hints/HintMessage.java +++ b/src/java/org/apache/cassandra/hints/HintMessage.java @@ -19,6 +19,7 @@ package org.apache.cassandra.hints; import java.io.IOException; +import java.nio.ByteBuffer; 
import java.util.Objects; import java.util.UUID; @@ -28,11 +29,9 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.exceptions.UnknownTableException; -import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.io.util.TrackedDataInputPlus; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.utils.UUIDSerializer; @@ -49,9 +48,9 @@ * Scenario (2) means that we got a hint from a node that's going through decommissioning and is streaming its hints * elsewhere first. */ -public final class HintMessage +public final class HintMessage implements SerializableHintMessage { - public static final IVersionedSerializer serializer = new Serializer(); + public static final IVersionedAsymmetricSerializer serializer = new Serializer(); final UUID hostId; @@ -75,37 +74,72 @@ public final class HintMessage this.unknownTableID = unknownTableID; } - public MessageOut createMessageOut() + public static class Serializer implements IVersionedAsymmetricSerializer { - return new MessageOut<>(MessagingService.Verb.HINT, this, serializer); - } - - public static class Serializer implements IVersionedSerializer - { - public long serializedSize(HintMessage message, int version) + public long serializedSize(SerializableHintMessage obj, int version) { - long size = UUIDSerializer.serializer.serializedSize(message.hostId, version); + if (obj instanceof HintMessage) + { + HintMessage message = (HintMessage) obj; + long size = UUIDSerializer.serializer.serializedSize(message.hostId, version); + + long hintSize = Hint.serializer.serializedSize(message.hint, version); + size += TypeSizes.sizeofUnsignedVInt(hintSize); + size += hintSize; + + return size; + } + else if (obj 
instanceof Encoded) + { + Encoded message = (Encoded) obj; - long hintSize = Hint.serializer.serializedSize(message.hint, version); - size += TypeSizes.sizeofUnsignedVInt(hintSize); - size += hintSize; + if (version != message.version) + throw new IllegalArgumentException("serializedSize() called with non-matching version " + version); - return size; + long size = UUIDSerializer.serializer.serializedSize(message.hostId, version); + size += TypeSizes.sizeofUnsignedVInt(message.hint.remaining()); + size += message.hint.remaining(); + return size; + } + else + { + throw new IllegalStateException("Unexpected type: " + obj); + } } - public void serialize(HintMessage message, DataOutputPlus out, int version) throws IOException + public void serialize(SerializableHintMessage obj, DataOutputPlus out, int version) throws IOException { - Objects.requireNonNull(message.hint); // we should never *send* a HintMessage with null hint + if (obj instanceof HintMessage) + { + HintMessage message = (HintMessage) obj; - UUIDSerializer.serializer.serialize(message.hostId, out, version); + Objects.requireNonNull(message.hint); // we should never *send* a HintMessage with null hint - /* - * We are serializing the hint size so that the receiver of the message could gracefully handle - * deserialize failure when a table had been dropped, by simply skipping the unread bytes. - */ - out.writeUnsignedVInt(Hint.serializer.serializedSize(message.hint, version)); + UUIDSerializer.serializer.serialize(message.hostId, out, version); - Hint.serializer.serialize(message.hint, out, version); + /* + * We are serializing the hint size so that the receiver of the message could gracefully handle + * deserialize failure when a table had been dropped, by simply skipping the unread bytes. 
+ */ + out.writeUnsignedVInt(Hint.serializer.serializedSize(message.hint, version)); + + Hint.serializer.serialize(message.hint, out, version); + } + else if (obj instanceof Encoded) + { + Encoded message = (Encoded) obj; + + if (version != message.version) + throw new IllegalArgumentException("serialize() called with non-matching version " + version); + + UUIDSerializer.serializer.serialize(message.hostId, out, version); + out.writeUnsignedVInt(message.hint.remaining()); + out.write(message.hint); + } + else + { + throw new IllegalStateException("Unexpected type: " + obj); + } } /* @@ -130,4 +164,32 @@ public HintMessage deserialize(DataInputPlus in, int version) throws IOException } } } + + /** + * A specialized version of {@link HintMessage} that takes an already encoded in a bytebuffer hint and sends it verbatim. + * + * An optimization for when dispatching a hint file of the current messaging version to a node of the same messaging version, + * which is the most common case. Saves on extra ByteBuffer allocations one redundant hint deserialization-serialization cycle. + * + * Never deserialized as an HintMessage.Encoded - the receiving side will always deserialize the message as vanilla + * {@link HintMessage}. 
+ */ + static final class Encoded implements SerializableHintMessage + { + private final UUID hostId; + private final ByteBuffer hint; + private final int version; + + Encoded(UUID hostId, ByteBuffer hint, int version) + { + this.hostId = hostId; + this.hint = hint; + this.version = version; + } + + public long getHintCreationTime() + { + return Hint.serializer.getHintCreationTime(hint, version); + } + } } diff --git a/src/java/org/apache/cassandra/hints/HintResponse.java b/src/java/org/apache/cassandra/hints/HintResponse.java deleted file mode 100644 index 8aa888f55154..000000000000 --- a/src/java/org/apache/cassandra/hints/HintResponse.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.cassandra.hints; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; - -/** - * An empty successful response to a HintMessage. 
- */ -public final class HintResponse -{ - public static final IVersionedSerializer serializer = new Serializer(); - - static final HintResponse instance = new HintResponse(); - static final MessageOut message = - new MessageOut<>(MessagingService.Verb.REQUEST_RESPONSE, instance, serializer); - - private HintResponse() - { - } - - private static final class Serializer implements IVersionedSerializer - { - public long serializedSize(HintResponse response, int version) - { - return 0; - } - - public void serialize(HintResponse response, DataOutputPlus out, int version) - { - } - - public HintResponse deserialize(DataInputPlus in, int version) - { - return instance; - } - } -} diff --git a/src/java/org/apache/cassandra/hints/HintVerbHandler.java b/src/java/org/apache/cassandra/hints/HintVerbHandler.java index cec6f0b2b448..2fbe4754b62a 100644 --- a/src/java/org/apache/cassandra/hints/HintVerbHandler.java +++ b/src/java/org/apache/cassandra/hints/HintVerbHandler.java @@ -26,7 +26,7 @@ import org.apache.cassandra.db.partitions.PartitionUpdate; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.serializers.MarshalException; import org.apache.cassandra.service.StorageProxy; @@ -41,9 +41,11 @@ */ public final class HintVerbHandler implements IVerbHandler { + public static final HintVerbHandler instance = new HintVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(HintVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { UUID hostId = message.payload.hostId; Hint hint = message.payload.hint; @@ -59,7 +61,7 @@ public void doVerb(MessageIn message, int id) address, hostId, message.payload.unknownTableID); - reply(id, message.from); + respond(message); return; } @@ -71,7 +73,7 @@ public 
void doVerb(MessageIn message, int id) catch (MarshalException e) { logger.warn("Failed to validate a hint for {}: {} - skipped", address, hostId); - reply(id, message.from); + respond(message); return; } @@ -80,24 +82,24 @@ public void doVerb(MessageIn message, int id) // the node is not the final destination of the hint (must have gotten it from a decommissioning node), // so just store it locally, to be delivered later. HintsService.instance.write(hostId, hint); - reply(id, message.from); + respond(message); } else if (!StorageProxy.instance.appliesLocally(hint.mutation)) { // the topology has changed, and we are no longer a replica of the mutation - since we don't know which node(s) // it has been handed over to, re-address the hint to all replicas; see CASSANDRA-5902. HintsService.instance.writeForAllReplicas(hint); - reply(id, message.from); + respond(message); } else { // the common path - the node is both the destination and a valid replica for the hint. - hint.applyFuture().thenAccept(o -> reply(id, message.from)).exceptionally(e -> {logger.debug("Failed to apply hint", e); return null;}); + hint.applyFuture().thenAccept(o -> respond(message)).exceptionally(e -> {logger.debug("Failed to apply hint", e); return null;}); } } - private static void reply(int id, InetAddressAndPort to) + private static void respond(Message respondTo) { - MessagingService.instance().sendReply(HintResponse.message, id, to); + MessagingService.instance().send(respondTo.emptyResponse(), respondTo.from()); } } diff --git a/src/java/org/apache/cassandra/hints/HintsCatalog.java b/src/java/org/apache/cassandra/hints/HintsCatalog.java index 7d5c8e6006ae..5a92889ff46a 100644 --- a/src/java/org/apache/cassandra/hints/HintsCatalog.java +++ b/src/java/org/apache/cassandra/hints/HintsCatalog.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import 
java.util.stream.Stream; @@ -64,10 +65,10 @@ private HintsCatalog(File hintsDirectory, ImmutableMap writerPar */ static HintsCatalog load(File hintsDirectory, ImmutableMap writerParams) { - try + try(Stream list = Files.list(hintsDirectory.toPath())) { Map> stores = - Files.list(hintsDirectory.toPath()) + list .filter(HintsDescriptor::isHintFileName) .map(HintsDescriptor::readFromFileQuietly) .filter(Optional::isPresent) diff --git a/src/java/org/apache/cassandra/hints/HintsDispatcher.java b/src/java/org/apache/cassandra/hints/HintsDispatcher.java index 2cff18608c0d..39e4b25c0b22 100644 --- a/src/java/org/apache/cassandra/hints/HintsDispatcher.java +++ b/src/java/org/apache/cassandra/hints/HintsDispatcher.java @@ -20,7 +20,6 @@ import java.io.File; import java.nio.ByteBuffer; import java.util.*; -import java.util.concurrent.TimeUnit; import java.util.function.BooleanSupplier; import java.util.function.Function; @@ -28,19 +27,21 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.cassandra.db.monitoring.ApproximateTime; +import org.apache.cassandra.net.RequestCallback; import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.metrics.HintsServiceMetrics; -import org.apache.cassandra.net.IAsyncCallbackWithFailure; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.utils.concurrent.SimpleCondition; +import static org.apache.cassandra.net.Verb.HINT_REQ; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + /** * Dispatches a single hints file to a specified node in a batched manner. 
* - * Uses either {@link EncodedHintMessage} - when dispatching hints into a node with the same messaging version as the hints file, + * Uses either {@link HintMessage.Encoded} - when dispatching hints into a node with the same messaging version as the hints file, * or {@link HintMessage}, when conversion is required. */ final class HintsDispatcher implements AutoCloseable @@ -70,7 +71,7 @@ private HintsDispatcher(HintsReader reader, UUID hostId, InetAddressAndPort addr static HintsDispatcher create(File file, RateLimiter rateLimiter, InetAddressAndPort address, UUID hostId, BooleanSupplier abortRequested) { - int messagingVersion = MessagingService.instance().getVersion(address); + int messagingVersion = MessagingService.instance().versions.get(address); HintsDispatcher dispatcher = new HintsDispatcher(HintsReader.open(file, rateLimiter), hostId, address, messagingVersion, abortRequested); HintDiagnostics.dispatcherCreated(dispatcher); return dispatcher; @@ -187,8 +188,8 @@ private Action sendHints(Iterator hints, Collection callbacks, private Callback sendHint(Hint hint) { Callback callback = new Callback(hint.creationTime); - HintMessage message = new HintMessage(hostId, hint); - MessagingService.instance().sendRRWithFailure(message.createMessageOut(), address, callback); + Message message = Message.out(HINT_REQ, new HintMessage(hostId, hint)); + MessagingService.instance().sendWithCallback(message, address, callback); return callback; } @@ -198,34 +199,32 @@ private Callback sendHint(Hint hint) private Callback sendEncodedHint(ByteBuffer hint) { - EncodedHintMessage message = new EncodedHintMessage(hostId, hint, messagingVersion); + HintMessage.Encoded message = new HintMessage.Encoded(hostId, hint, messagingVersion); Callback callback = new Callback(message.getHintCreationTime()); - MessagingService.instance().sendRRWithFailure(message.createMessageOut(), address, callback); + MessagingService.instance().sendWithCallback(Message.out(HINT_REQ, message), 
address, callback); return callback; } - private static final class Callback implements IAsyncCallbackWithFailure + private static final class Callback implements RequestCallback { enum Outcome { SUCCESS, TIMEOUT, FAILURE, INTERRUPTED } - private final long start = System.nanoTime(); + private final long start = approxTime.now(); private final SimpleCondition condition = new SimpleCondition(); private volatile Outcome outcome; - private final long hintCreationTime; + private final long hintCreationNanoTime; - private Callback(long hintCreationTime) + private Callback(long hintCreationTimeMillisSinceEpoch) { - this.hintCreationTime = hintCreationTime; + this.hintCreationNanoTime = approxTime.translate().fromMillisSinceEpoch(hintCreationTimeMillisSinceEpoch); } Outcome await() { - long timeout = TimeUnit.MILLISECONDS.toNanos(MessagingService.Verb.HINT.getTimeout()) - (System.nanoTime() - start); boolean timedOut; - try { - timedOut = !condition.await(timeout, TimeUnit.NANOSECONDS); + timedOut = !condition.awaitUntil(HINT_REQ.expiresAtNanos(start)); } catch (InterruptedException e) { @@ -236,24 +235,27 @@ Outcome await() return timedOut ? 
Outcome.TIMEOUT : outcome; } + @Override + public boolean invokeOnFailure() + { + return true; + } + + @Override public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) { outcome = Outcome.FAILURE; condition.signalAll(); } - public void response(MessageIn msg) + @Override + public void onResponse(Message msg) { - HintsServiceMetrics.updateDelayMetrics(msg.from, ApproximateTime.currentTimeMillis() - this.hintCreationTime); + HintsServiceMetrics.updateDelayMetrics(msg.from(), approxTime.now() - this.hintCreationNanoTime); outcome = Outcome.SUCCESS; condition.signalAll(); } - public boolean isLatencyForSnitch() - { - return false; - } - @Override public boolean supportsBackPressure() { diff --git a/src/java/org/apache/cassandra/locator/ILatencySubscriber.java b/src/java/org/apache/cassandra/hints/SerializableHintMessage.java similarity index 84% rename from src/java/org/apache/cassandra/locator/ILatencySubscriber.java rename to src/java/org/apache/cassandra/hints/SerializableHintMessage.java index f6c1c7f20f88..43c289c7518e 100644 --- a/src/java/org/apache/cassandra/locator/ILatencySubscriber.java +++ b/src/java/org/apache/cassandra/hints/SerializableHintMessage.java @@ -15,9 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.cassandra.locator; -public interface ILatencySubscriber +package org.apache.cassandra.hints; + +public interface SerializableHintMessage { - public void receiveTiming(InetAddressAndPort address, long latency); } diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java index ec54a652a5f2..779ea7c68857 100644 --- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java +++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java @@ -73,9 +73,11 @@ import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.JVMStabilityInspector; -import org.apache.cassandra.utils.concurrent.OpOrder; import org.apache.cassandra.utils.concurrent.Refs; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdown; + /** * Handles the core maintenance functionality associated with indexes: adding/removing them to or from * a table, (re)building during bootstrap or other streaming operations, flushing, reloading metadata @@ -786,7 +788,7 @@ private void flushIndexesBlocking(Set indexes, FutureCallback cal { indexes.forEach(index -> index.getBackingTable() - .map(cfs -> wait.add(cfs.forceFlush())) + .map(cfs -> wait.add(cfs.forceFlushToSSTable())) .orElseGet(() -> nonCfsIndexes.add(index))); } @@ -925,7 +927,7 @@ public int calculateIndexingPageSize() if (meanPartitionSize <= 0) return DEFAULT_PAGE_SIZE; - int meanCellsPerPartition = baseCfs.getMeanColumns(); + int meanCellsPerPartition = baseCfs.getMeanEstimatedCellPerPartitionCount(); if (meanCellsPerPartition <= 0) return DEFAULT_PAGE_SIZE; @@ -1487,12 +1489,9 @@ public void handleNotification(INotification notification, Object sender) } @VisibleForTesting - public static void shutdownExecutors() throws InterruptedException + public static void shutdownAndWait(long 
timeout, TimeUnit units) throws InterruptedException, TimeoutException { - ExecutorService[] executors = new ExecutorService[]{ asyncExecutor, blockingExecutor }; - for (ExecutorService executor : executors) - executor.shutdown(); - for (ExecutorService executor : executors) - executor.awaitTermination(60, TimeUnit.SECONDS); + shutdown(asyncExecutor, blockingExecutor); + awaitTermination(timeout, units, asyncExecutor, blockingExecutor); } } diff --git a/src/java/org/apache/cassandra/index/internal/CassandraIndex.java b/src/java/org/apache/cassandra/index/internal/CassandraIndex.java index ecd25fd89a6b..d1feb7917d1a 100644 --- a/src/java/org/apache/cassandra/index/internal/CassandraIndex.java +++ b/src/java/org/apache/cassandra/index/internal/CassandraIndex.java @@ -61,7 +61,6 @@ import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.Pair; -import org.apache.cassandra.utils.concurrent.OpOrder; import org.apache.cassandra.utils.concurrent.Refs; import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse; @@ -186,7 +185,7 @@ public Optional getBackingTable() public Callable getBlockingFlushTask() { return () -> { - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); return null; }; } @@ -270,7 +269,7 @@ public AbstractType customExpressionValueType() public long getEstimatedResultRows() { - return indexCfs.getMeanColumns(); + return indexCfs.getMeanRowCount(); } /** @@ -663,7 +662,7 @@ private void invalidate() CompactionManager.instance.interruptCompactionForCFs(cfss, (sstable) -> true, true); CompactionManager.instance.waitForCessation(cfss, (sstable) -> true); Keyspace.writeOrder.awaitNewBarrier(); - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); indexCfs.readOrdering.awaitNewBarrier(); indexCfs.invalidate(); } @@ -689,7 +688,7 @@ private Callable getBuildIndexTask() @SuppressWarnings("resource") private void 
buildBlocking() { - baseCfs.forceBlockingFlush(); + baseCfs.forceBlockingFlushToSSTable(); try (ColumnFamilyStore.RefViewFragment viewFragment = baseCfs.selectAndReference(View.selectFunction(SSTableSet.CANONICAL)); Refs sstables = viewFragment.refs) @@ -713,7 +712,7 @@ private void buildBlocking() ImmutableSet.copyOf(sstables)); Future future = CompactionManager.instance.submitIndexBuild(builder); FBUtilities.waitOnFuture(future); - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); } logger.info("Index build of {} complete", metadata.name); } diff --git a/src/java/org/apache/cassandra/index/sasi/SASIIndex.java b/src/java/org/apache/cassandra/index/sasi/SASIIndex.java index 19c09cc51624..07327ea6d6ca 100644 --- a/src/java/org/apache/cassandra/index/sasi/SASIIndex.java +++ b/src/java/org/apache/cassandra/index/sasi/SASIIndex.java @@ -19,6 +19,7 @@ import java.util.*; import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import com.googlecode.concurrenttrees.common.Iterables; @@ -60,6 +61,8 @@ import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.concurrent.OpOrder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public class SASIIndex implements Index, INotificationConsumer { public final static String USAGE_WARNING = "SASI indexes are experimental and are not recommended for production use."; @@ -295,7 +298,7 @@ public Searcher searcherFor(ReadCommand command) throws InvalidRequestException { TableMetadata config = command.metadata(); ColumnFamilyStore cfs = Schema.instance.getColumnFamilyStoreInstance(config.id); - return controller -> new QueryPlan(cfs, command, DatabaseDescriptor.getRangeRpcTimeout()).execute(controller); + return controller -> new QueryPlan(cfs, command, DatabaseDescriptor.getRangeRpcTimeout(MILLISECONDS)).execute(controller); } public SSTableFlushObserver getFlushObserver(Descriptor descriptor, OperationType opType) diff --git 
a/src/java/org/apache/cassandra/io/DummyByteVersionedSerializer.java b/src/java/org/apache/cassandra/io/DummyByteVersionedSerializer.java deleted file mode 100644 index d82ff7d7da07..000000000000 --- a/src/java/org/apache/cassandra/io/DummyByteVersionedSerializer.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.io; - -import java.io.IOException; - -import com.google.common.base.Preconditions; - -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessagingService; - -/** - * Serializes a dummy byte that can't be set. Will always write 0 and return 0 in a correctly formed message. 
- */ -public class DummyByteVersionedSerializer implements IVersionedSerializer -{ - public static final DummyByteVersionedSerializer instance = new DummyByteVersionedSerializer(); - - private DummyByteVersionedSerializer() {} - - public void serialize(byte[] bytes, DataOutputPlus out, int version) throws IOException - { - Preconditions.checkArgument(bytes == MessagingService.ONE_BYTE); - out.write(0); - } - - public byte[] deserialize(DataInputPlus in, int version) throws IOException - { - assert(0 == in.readByte()); - return MessagingService.ONE_BYTE; - } - - public long serializedSize(byte[] bytes, int version) - { - //Payload - return 1; - } -} diff --git a/src/java/org/apache/cassandra/io/IVersionedAsymmetricSerializer.java b/src/java/org/apache/cassandra/io/IVersionedAsymmetricSerializer.java new file mode 100644 index 000000000000..8ad2c285c326 --- /dev/null +++ b/src/java/org/apache/cassandra/io/IVersionedAsymmetricSerializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.io; + +import java.io.IOException; + +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; + +public interface IVersionedAsymmetricSerializer +{ + /** + * Serialize the specified type into the specified DataOutputStream instance. + * + * @param t type that needs to be serialized + * @param out DataOutput into which serialization needs to happen. + * @param version protocol version + * @throws IOException if serialization fails + */ + public void serialize(In t, DataOutputPlus out, int version) throws IOException; + + /** + * Deserialize into the specified DataInputStream instance. + * @param in DataInput from which deserialization needs to happen. + * @param version protocol version + * @return the type that was deserialized + * @throws IOException if deserialization fails + */ + public Out deserialize(DataInputPlus in, int version) throws IOException; + + /** + * Calculate serialized size of object without actually serializing. + * @param t object to calculate serialized size + * @param version protocol version + * @return serialized size of object t + */ + public long serializedSize(In t, int version); +} diff --git a/src/java/org/apache/cassandra/io/IVersionedSerializer.java b/src/java/org/apache/cassandra/io/IVersionedSerializer.java index e5555735079d..6730ec08249e 100644 --- a/src/java/org/apache/cassandra/io/IVersionedSerializer.java +++ b/src/java/org/apache/cassandra/io/IVersionedSerializer.java @@ -17,37 +17,6 @@ */ package org.apache.cassandra.io; -import java.io.IOException; - -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; - -public interface IVersionedSerializer +public interface IVersionedSerializer extends IVersionedAsymmetricSerializer { - /** - * Serialize the specified type into the specified DataOutputStream instance. 
- * - * @param t type that needs to be serialized - * @param out DataOutput into which serialization needs to happen. - * @param version protocol version - * @throws java.io.IOException if serialization fails - */ - public void serialize(T t, DataOutputPlus out, int version) throws IOException; - - /** - * Deserialize into the specified DataInputStream instance. - * @param in DataInput from which deserialization needs to happen. - * @param version protocol version - * @return the type that was deserialized - * @throws IOException if deserialization fails - */ - public T deserialize(DataInputPlus in, int version) throws IOException; - - /** - * Calculate serialized size of object without actually serializing. - * @param t object to calculate serialized size - * @param version protocol version - * @return serialized size of object t - */ - public long serializedSize(T t, int version); } diff --git a/src/java/org/apache/cassandra/io/sstable/IndexSummaryManager.java b/src/java/org/apache/cassandra/io/sstable/IndexSummaryManager.java index 2d58cf8e2174..1f4059aa8a61 100644 --- a/src/java/org/apache/cassandra/io/sstable/IndexSummaryManager.java +++ b/src/java/org/apache/cassandra/io/sstable/IndexSummaryManager.java @@ -25,6 +25,7 @@ import java.util.Set; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; @@ -44,11 +45,15 @@ import org.apache.cassandra.db.lifecycle.View; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.schema.TableId; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.MBeanWrapper; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.WrappedRunnable; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static 
org.apache.cassandra.utils.ExecutorUtils.shutdown; + /** * Manages the fixed-size memory pool for index summaries, periodically resizing them * in order to give more memory to hot sstables and less memory to cold sstables. @@ -264,4 +269,10 @@ public static List redistributeSummaries(IndexSummaryRedistributi { return CompactionManager.instance.runIndexSummaryRedistribution(redistribution); } + + @VisibleForTesting + public void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException + { + ExecutorUtils.shutdownAndWait(timeout, unit, executor); + } } diff --git a/src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java b/src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java index e3059c873f6d..36a1e633080d 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java +++ b/src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java @@ -56,6 +56,7 @@ import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.exceptions.UnknownColumnException; import org.apache.cassandra.io.FSError; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.*; @@ -431,13 +432,22 @@ public static SSTableReader openForBatch(Descriptor descriptor, Set c long fileLength = new File(descriptor.filenameFor(Component.DATA)).length(); if (logger.isDebugEnabled()) logger.debug("Opening {} ({})", descriptor, FBUtilities.prettyPrintMemory(fileLength)); - SSTableReader sstable = internalOpen(descriptor, - components, - metadata, - System.currentTimeMillis(), - statsMetadata, - OpenReason.NORMAL, - header.toHeader(metadata.get())); + + final SSTableReader sstable; + try + { + sstable = internalOpen(descriptor, + components, + metadata, + System.currentTimeMillis(), + statsMetadata, + OpenReason.NORMAL, + header.toHeader(metadata.get())); + } + catch (UnknownColumnException e) + { + throw 
new IllegalStateException(e); + } try(FileHandle.Builder ibuilder = new FileHandle.Builder(sstable.descriptor.filenameFor(Component.PRIMARY_INDEX)) .mmapped(DatabaseDescriptor.getIndexAccessMode() == Config.DiskAccessMode.mmap) @@ -522,13 +532,22 @@ public static SSTableReader open(Descriptor descriptor, long fileLength = new File(descriptor.filenameFor(Component.DATA)).length(); if (logger.isDebugEnabled()) logger.debug("Opening {} ({})", descriptor, FBUtilities.prettyPrintMemory(fileLength)); - SSTableReader sstable = internalOpen(descriptor, - components, - metadata, - System.currentTimeMillis(), - statsMetadata, - OpenReason.NORMAL, - header.toHeader(metadata.get())); + + final SSTableReader sstable; + try + { + sstable = internalOpen(descriptor, + components, + metadata, + System.currentTimeMillis(), + statsMetadata, + OpenReason.NORMAL, + header.toHeader(metadata.get())); + } + catch (UnknownColumnException e) + { + throw new IllegalStateException(e); + } try { @@ -1924,9 +1943,9 @@ public EstimatedHistogram getEstimatedPartitionSize() return sstableMetadata.estimatedPartitionSize; } - public EstimatedHistogram getEstimatedColumnCount() + public EstimatedHistogram getEstimatedCellPerPartitionCount() { - return sstableMetadata.estimatedColumnCount; + return sstableMetadata.estimatedCellPerPartitionCount; } public double getEstimatedDroppableTombstoneRatio(int gcBefore) @@ -2475,13 +2494,10 @@ public static SSTableReader moveAndOpenSSTable(ColumnFamilyStore cfs, Descriptor return reader; } - public static void shutdownBlocking() throws InterruptedException + public static void shutdownBlocking(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - if (syncExecutor != null) - { - syncExecutor.shutdownNow(); - syncExecutor.awaitTermination(0, TimeUnit.SECONDS); - } + + ExecutorUtils.shutdownNowAndWait(timeout, unit, syncExecutor); resetTidying(); } } diff --git 
a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriter.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriter.java index 882638121305..f05ea94cb7ea 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriter.java @@ -17,6 +17,7 @@ */ package org.apache.cassandra.io.sstable.format.big; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.util.Collection; @@ -31,7 +32,6 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.db.lifecycle.LifecycleTransaction; import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.io.FSWriteError; import org.apache.cassandra.io.compress.BufferType; @@ -43,7 +43,7 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.SequentialWriter; import org.apache.cassandra.io.util.SequentialWriterOption; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.schema.TableMetadataRef; @@ -202,22 +202,25 @@ public void writeComponent(Component.Type type, DataInputPlus in, long size) { logger.info("Writing component {} to {} length {}", type, componentWriters.get(type).getPath(), prettyPrintMemory(size)); - if (in instanceof RebufferingByteBufDataInputPlus) - write((RebufferingByteBufDataInputPlus) in, size, componentWriters.get(type)); + if (in instanceof AsyncStreamingInputPlus) + write((AsyncStreamingInputPlus) in, size, componentWriters.get(type)); else write(in, size, componentWriters.get(type)); } - private void write(RebufferingByteBufDataInputPlus in, long size, SequentialWriter writer) + private void write(AsyncStreamingInputPlus in, long size, SequentialWriter writer) { 
logger.info("Block Writing component to {} length {}", writer.getPath(), prettyPrintMemory(size)); try { - long bytesWritten = in.consumeUntil(writer, size); - - if (bytesWritten != size) - throw new IOException(format("Failed to read correct number of bytes from channel %s", writer)); + in.consume(writer::writeDirectlyToChannel, size); + writer.sync(); + } + // FIXME: handle ACIP exceptions properly + catch (EOFException | AsyncStreamingInputPlus.InputTimeoutException e) + { + in.close(); } catch (IOException e) { diff --git a/src/java/org/apache/cassandra/io/sstable/metadata/StatsMetadata.java b/src/java/org/apache/cassandra/io/sstable/metadata/StatsMetadata.java index 5d464fef4cca..f4e5beba24b6 100755 --- a/src/java/org/apache/cassandra/io/sstable/metadata/StatsMetadata.java +++ b/src/java/org/apache/cassandra/io/sstable/metadata/StatsMetadata.java @@ -47,7 +47,7 @@ public class StatsMetadata extends MetadataComponent public static final ISerializer> commitLogPositionSetSerializer = IntervalSet.serializer(CommitLogPosition.serializer); public final EstimatedHistogram estimatedPartitionSize; - public final EstimatedHistogram estimatedColumnCount; + public final EstimatedHistogram estimatedCellPerPartitionCount; public final IntervalSet commitLogIntervals; public final long minTimestamp; public final long maxTimestamp; @@ -70,7 +70,7 @@ public class StatsMetadata extends MetadataComponent public final EncodingStats encodingStats; public StatsMetadata(EstimatedHistogram estimatedPartitionSize, - EstimatedHistogram estimatedColumnCount, + EstimatedHistogram estimatedCellPerPartitionCount, IntervalSet commitLogIntervals, long minTimestamp, long maxTimestamp, @@ -91,7 +91,7 @@ public StatsMetadata(EstimatedHistogram estimatedPartitionSize, boolean isTransient) { this.estimatedPartitionSize = estimatedPartitionSize; - this.estimatedColumnCount = estimatedColumnCount; + this.estimatedCellPerPartitionCount = estimatedCellPerPartitionCount; this.commitLogIntervals = 
commitLogIntervals; this.minTimestamp = minTimestamp; this.maxTimestamp = maxTimestamp; @@ -124,7 +124,7 @@ public MetadataType getType() */ public double getEstimatedDroppableTombstoneRatio(int gcBefore) { - long estimatedColumnCount = this.estimatedColumnCount.mean() * this.estimatedColumnCount.count(); + long estimatedColumnCount = this.estimatedCellPerPartitionCount.mean() * this.estimatedCellPerPartitionCount.count(); if (estimatedColumnCount > 0) { double droppable = getDroppableTombstonesBefore(gcBefore); @@ -145,7 +145,7 @@ public double getDroppableTombstonesBefore(int gcBefore) public StatsMetadata mutateLevel(int newLevel) { return new StatsMetadata(estimatedPartitionSize, - estimatedColumnCount, + estimatedCellPerPartitionCount, commitLogIntervals, minTimestamp, maxTimestamp, @@ -169,7 +169,7 @@ public StatsMetadata mutateLevel(int newLevel) public StatsMetadata mutateRepairedMetadata(long newRepairedAt, UUID newPendingRepair, boolean newIsTransient) { return new StatsMetadata(estimatedPartitionSize, - estimatedColumnCount, + estimatedCellPerPartitionCount, commitLogIntervals, minTimestamp, maxTimestamp, @@ -199,7 +199,7 @@ public boolean equals(Object o) StatsMetadata that = (StatsMetadata) o; return new EqualsBuilder() .append(estimatedPartitionSize, that.estimatedPartitionSize) - .append(estimatedColumnCount, that.estimatedColumnCount) + .append(estimatedCellPerPartitionCount, that.estimatedCellPerPartitionCount) .append(commitLogIntervals, that.commitLogIntervals) .append(minTimestamp, that.minTimestamp) .append(maxTimestamp, that.maxTimestamp) @@ -225,7 +225,7 @@ public int hashCode() { return new HashCodeBuilder() .append(estimatedPartitionSize) - .append(estimatedColumnCount) + .append(estimatedCellPerPartitionCount) .append(commitLogIntervals) .append(minTimestamp) .append(maxTimestamp) @@ -252,7 +252,7 @@ public int serializedSize(Version version, StatsMetadata component) throws IOExc { int size = 0; size += 
EstimatedHistogram.serializer.serializedSize(component.estimatedPartitionSize); - size += EstimatedHistogram.serializer.serializedSize(component.estimatedColumnCount); + size += EstimatedHistogram.serializer.serializedSize(component.estimatedCellPerPartitionCount); size += CommitLogPosition.serializer.serializedSize(component.commitLogIntervals.upperBound().orElse(CommitLogPosition.NONE)); size += 8 + 8 + 4 + 4 + 4 + 4 + 8 + 8; // mix/max timestamp(long), min/maxLocalDeletionTime(int), min/max TTL, compressionRatio(double), repairedAt (long) size += TombstoneHistogram.serializer.serializedSize(component.estimatedTombstoneDropTime); @@ -290,7 +290,7 @@ public int serializedSize(Version version, StatsMetadata component) throws IOExc public void serialize(Version version, StatsMetadata component, DataOutputPlus out) throws IOException { EstimatedHistogram.serializer.serialize(component.estimatedPartitionSize, out); - EstimatedHistogram.serializer.serialize(component.estimatedColumnCount, out); + EstimatedHistogram.serializer.serialize(component.estimatedCellPerPartitionCount, out); CommitLogPosition.serializer.serialize(component.commitLogIntervals.upperBound().orElse(CommitLogPosition.NONE), out); out.writeLong(component.minTimestamp); out.writeLong(component.maxTimestamp); diff --git a/src/java/org/apache/cassandra/io/util/BufferedDataOutputStreamPlus.java b/src/java/org/apache/cassandra/io/util/BufferedDataOutputStreamPlus.java index 56d88f7c2b3f..7d1e91d641e6 100644 --- a/src/java/org/apache/cassandra/io/util/BufferedDataOutputStreamPlus.java +++ b/src/java/org/apache/cassandra/io/util/BufferedDataOutputStreamPlus.java @@ -28,8 +28,8 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.config.Config; +import org.apache.cassandra.utils.FastByteOperations; import org.apache.cassandra.utils.memory.MemoryUtil; -import org.apache.cassandra.utils.vint.VIntCoding; /** * An implementation of the DataOutputStreamPlus interface using a 
ByteBuffer to stage writes @@ -43,15 +43,6 @@ public class BufferedDataOutputStreamPlus extends DataOutputStreamPlus protected ByteBuffer buffer; - //Allow derived classes to specify writing to the channel - //directly shouldn't happen because they intercept via doFlush for things - //like compression or checksumming - //Another hack for this value is that it also indicates that flushing early - //should not occur, flushes aligned with buffer size are desired - //Unless... it's the last flush. Compression and checksum formats - //expect block (same as buffer size) alignment for everything except the last block - protected boolean strictFlushing = false; - public BufferedDataOutputStreamPlus(RandomAccessFile ras) { this(ras.getChannel()); @@ -132,9 +123,6 @@ public void write(byte[] b, int off, int len) throws IOException } } - // ByteBuffer to use for defensive copies - private final ByteBuffer hollowBuffer = MemoryUtil.getHollowDirectByteBuffer(); - /* * Makes a defensive copy of the incoming ByteBuffer and don't modify the position or limit * even temporarily so it is thread-safe WRT to the incoming buffer @@ -142,48 +130,20 @@ public void write(byte[] b, int off, int len) throws IOException * @see org.apache.cassandra.io.util.DataOutputPlus#write(java.nio.ByteBuffer) */ @Override - public void write(ByteBuffer toWrite) throws IOException - { - if (toWrite.hasArray()) - { - write(toWrite.array(), toWrite.arrayOffset() + toWrite.position(), toWrite.remaining()); - } - else - { - assert toWrite.isDirect(); - MemoryUtil.duplicateDirectByteBuffer(toWrite, hollowBuffer); - int toWriteRemaining = toWrite.remaining(); - - if (toWriteRemaining > buffer.remaining()) - { - if (strictFlushing) - { - writeExcessSlow(); - } - else - { - doFlush(toWriteRemaining - buffer.remaining()); - while (hollowBuffer.remaining() > buffer.capacity()) - channel.write(hollowBuffer); - } - } - - buffer.put(hollowBuffer); - } - } - - // writes anything we can't fit into the buffer - 
@DontInline - private void writeExcessSlow() throws IOException + public void write(ByteBuffer src) throws IOException { - int originalLimit = hollowBuffer.limit(); - while (originalLimit - hollowBuffer.position() > buffer.remaining()) + int srcPos = src.position(); + int srcCount; + int trgAvailable; + while ((srcCount = src.limit() - srcPos) > (trgAvailable = buffer.remaining())) { - hollowBuffer.limit(hollowBuffer.position() + buffer.remaining()); - buffer.put(hollowBuffer); - doFlush(originalLimit - hollowBuffer.position()); + FastByteOperations.copy(src, srcPos, buffer, buffer.position(), trgAvailable); + buffer.position(buffer.position() + trgAvailable); + srcPos += trgAvailable; + doFlush(src.limit() - srcPos); } - hollowBuffer.limit(originalLimit); + FastByteOperations.copy(src, srcPos, buffer, buffer.position(), srcCount); + buffer.position(buffer.position() + srcCount); } @Override @@ -241,25 +201,6 @@ public void writeLong(long v) throws IOException buffer.putLong(v); } - @Override - public void writeVInt(long value) throws IOException - { - writeUnsignedVInt(VIntCoding.encodeZigZag64(value)); - } - - @Override - public void writeUnsignedVInt(long value) throws IOException - { - int size = VIntCoding.computeUnsignedVIntSize(value); - if (size == 1) - { - write((int) value); - return; - } - - write(VIntCoding.encodeVInt(value, size), 0, size); - } - @Override public void writeFloat(float v) throws IOException { @@ -302,13 +243,6 @@ public void writeUTF(String s) throws IOException UnbufferedDataOutputStreamPlus.writeUTF(s, this); } - @Override - public void write(Memory memory, long offset, long length) throws IOException - { - for (ByteBuffer buffer : memory.asByteBuffers(offset, length)) - write(buffer); - } - /* * Count is the number of bytes remaining to write ignoring already remaining capacity */ @@ -338,16 +272,6 @@ public void close() throws IOException buffer = null; } - @Override - public R applyToChannel(CheckedFunction f) throws IOException - 
{ - if (strictFlushing) - throw new UnsupportedOperationException(); - //Don't allow writes to the underlying channel while data is buffered - flush(); - return f.apply(channel); - } - public BufferedDataOutputStreamPlus order(ByteOrder order) { this.buffer.order(order); diff --git a/src/java/org/apache/cassandra/io/util/DataOutputBuffer.java b/src/java/org/apache/cassandra/io/util/DataOutputBuffer.java index 28ca4680a7fc..a6c7086b979e 100644 --- a/src/java/org/apache/cassandra/io/util/DataOutputBuffer.java +++ b/src/java/org/apache/cassandra/io/util/DataOutputBuffer.java @@ -201,6 +201,19 @@ public void close() public ByteBuffer buffer() { + return buffer(true); + } + + public ByteBuffer buffer(boolean duplicate) + { + if (!duplicate) + { + ByteBuffer buf = buffer; + buf.flip(); + buffer = null; + return buf; + } + ByteBuffer result = buffer.duplicate(); result.flip(); return result; diff --git a/src/java/org/apache/cassandra/io/util/DataOutputPlus.java b/src/java/org/apache/cassandra/io/util/DataOutputPlus.java index 16be42f889be..b94d097b6031 100644 --- a/src/java/org/apache/cassandra/io/util/DataOutputPlus.java +++ b/src/java/org/apache/cassandra/io/util/DataOutputPlus.java @@ -20,7 +20,6 @@ import java.io.DataOutput; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; import org.apache.cassandra.utils.vint.VIntCoding; @@ -33,13 +32,11 @@ public interface DataOutputPlus extends DataOutput // write the buffer without modifying its position void write(ByteBuffer buffer) throws IOException; - void write(Memory memory, long offset, long length) throws IOException; - - /** - * Safe way to operate against the underlying channel. 
Impossible to stash a reference to the channel - * and forget to flush - */ - R applyToChannel(CheckedFunction c) throws IOException; + default void write(Memory memory, long offset, long length) throws IOException + { + for (ByteBuffer buffer : memory.asByteBuffers(offset, length)) + write(buffer); + } default void writeVInt(long i) throws IOException { diff --git a/src/java/org/apache/cassandra/io/util/DataOutputStreamPlus.java b/src/java/org/apache/cassandra/io/util/DataOutputStreamPlus.java index 4adb6d20fcc4..e931899c0763 100644 --- a/src/java/org/apache/cassandra/io/util/DataOutputStreamPlus.java +++ b/src/java/org/apache/cassandra/io/util/DataOutputStreamPlus.java @@ -119,7 +119,7 @@ public int write(ByteBuffer src) throws IOException { int toWriteThisTime = Math.min(buf.length, toWrite - totalWritten); - ByteBufferUtil.arrayCopy(src, src.position() + totalWritten, buf, 0, toWriteThisTime); + ByteBufferUtil.copyBytes(src, src.position() + totalWritten, buf, 0, toWriteThisTime); DataOutputStreamPlus.this.write(buf, 0, toWriteThisTime); diff --git a/src/java/org/apache/cassandra/io/util/FastByteArrayInputStream.java b/src/java/org/apache/cassandra/io/util/FastByteArrayInputStream.java deleted file mode 100644 index f61546c95e85..000000000000 --- a/src/java/org/apache/cassandra/io/util/FastByteArrayInputStream.java +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.io.util; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; - -/* - * This file has been modified from Apache Harmony's ByteArrayInputStream - * implementation. The synchronized methods of the original have been - * replaced by non-synchronized methods. This makes this certain operations - * FASTer, but also *not thread-safe*. - * - * This file remains formatted the same as the Apache Harmony original to - * make patching easier if any bug fixes are made to the Harmony version. - */ - -/** - * A specialized {@link InputStream } for reading the contents of a byte array. - * - * @see ByteArrayInputStream - */ -public class FastByteArrayInputStream extends InputStream -{ - /** - * The {@code byte} array containing the bytes to stream over. - */ - protected byte[] buf; - - /** - * The current position within the byte array. - */ - protected int pos; - - /** - * The current mark position. Initially set to 0 or the offset - * parameter within the constructor. - */ - protected int mark; - - /** - * The total number of bytes initially available in the byte array - * {@code buf}. - */ - protected int count; - - /** - * Constructs a new {@code ByteArrayInputStream} on the byte array - * {@code buf}. - * - * @param buf - * the byte array to stream over. 
- */ - public FastByteArrayInputStream(byte buf[]) - { - this.mark = 0; - this.buf = buf; - this.count = buf.length; - } - - /** - * Constructs a new {@code ByteArrayInputStream} on the byte array - * {@code buf} with the initial position set to {@code offset} and the - * number of bytes available set to {@code offset} + {@code length}. - * - * @param buf - * the byte array to stream over. - * @param offset - * the initial position in {@code buf} to start streaming from. - * @param length - * the number of bytes available for streaming. - */ - public FastByteArrayInputStream(byte buf[], int offset, int length) - { - this.buf = buf; - pos = offset; - mark = offset; - count = offset + length > buf.length ? buf.length : offset + length; - } - - /** - * Returns the number of bytes that are available before this stream will - * block. This method returns the number of bytes yet to be read from the - * source byte array. - * - * @return the number of bytes available before blocking. - */ - @Override - public int available() - { - return count - pos; - } - - /** - * Closes this stream and frees resources associated with this stream. - * - * @throws IOException - * if an I/O error occurs while closing this stream. - */ - @Override - public void close() throws IOException - { - // Do nothing on close, this matches JDK behaviour. - } - - /** - * Sets a mark position in this ByteArrayInputStream. The parameter - * {@code readlimit} is ignored. Sending {@code reset()} will reposition the - * stream back to the marked position. - * - * @param readlimit - * ignored. - * @see #markSupported() - * @see #reset() - */ - @Override - public void mark(int readlimit) - { - mark = pos; - } - - /** - * Indicates whether this stream supports the {@code mark()} and - * {@code reset()} methods. Returns {@code true} since this class supports - * these methods. - * - * @return always {@code true}. 
- * @see #mark(int) - * @see #reset() - */ - @Override - public boolean markSupported() - { - return true; - } - - /** - * Reads a single byte from the source byte array and returns it as an - * integer in the range from 0 to 255. Returns -1 if the end of the source - * array has been reached. - * - * @return the byte read or -1 if the end of this stream has been reached. - */ - @Override - public int read() - { - return pos < count ? buf[pos++] & 0xFF : -1; - } - - /** - * Reads at most {@code len} bytes from this stream and stores - * them in byte array {@code b} starting at {@code offset}. This - * implementation reads bytes from the source byte array. - * - * @param b - * the byte array in which to store the bytes read. - * @param offset - * the initial position in {@code b} to store the bytes read from - * this stream. - * @param length - * the maximum number of bytes to store in {@code b}. - * @return the number of bytes actually read or -1 if no bytes were read and - * the end of the stream was encountered. - * @throws IndexOutOfBoundsException - * if {@code offset < 0} or {@code length < 0}, or if - * {@code offset + length} is greater than the size of - * {@code b}. - * @throws NullPointerException - * if {@code b} is {@code null}. - */ - @Override - public int read(byte b[], int offset, int length) - { - if (b == null) { - throw new NullPointerException(); - } - // avoid int overflow - if (offset < 0 || offset > b.length || length < 0 - || length > b.length - offset) - { - throw new IndexOutOfBoundsException(); - } - // Are there any bytes available? - if (this.pos >= this.count) - { - return -1; - } - if (length == 0) - { - return 0; - } - - int copylen = this.count - pos < length ? this.count - pos : length; - System.arraycopy(buf, pos, b, offset, copylen); - pos += copylen; - return copylen; - } - - /** - * Resets this stream to the last marked location. 
This implementation - * resets the position to either the marked position, the start position - * supplied in the constructor or 0 if neither has been provided. - * - * @see #mark(int) - */ - @Override - public void reset() - { - pos = mark; - } - - /** - * Skips {@code count} number of bytes in this InputStream. Subsequent - * {@code read()}s will not return these bytes unless {@code reset()} is - * used. This implementation skips {@code count} number of bytes in the - * target stream. It does nothing and returns 0 if {@code n} is negative. - * - * @param n - * the number of bytes to skip. - * @return the number of bytes actually skipped. - */ - @Override - public long skip(long n) - { - if (n <= 0) - { - return 0; - } - int temp = pos; - pos = this.count - pos < n ? this.count : (int) (pos + n); - return pos - temp; - } -} diff --git a/src/java/org/apache/cassandra/io/util/FileUtils.java b/src/java/org/apache/cassandra/io/util/FileUtils.java index ed03715e5a22..6e5c00af6ecf 100644 --- a/src/java/org/apache/cassandra/io/util/FileUtils.java +++ b/src/java/org/apache/cassandra/io/util/FileUtils.java @@ -312,7 +312,7 @@ public static void close(Closeable... 
cs) throws IOException public static void close(Iterable cs) throws IOException { - IOException e = null; + Throwable e = null; for (Closeable c : cs) { try @@ -320,14 +320,14 @@ public static void close(Iterable cs) throws IOException if (c != null) c.close(); } - catch (IOException ex) + catch (Throwable ex) { - e = ex; + if (e == null) e = ex; + else e.addSuppressed(ex); logger.warn("Failed closing stream {}", c, ex); } } - if (e != null) - throw e; + maybeFail(e, IOException.class); } public static void closeQuietly(Iterable cs) diff --git a/src/java/org/apache/cassandra/io/util/Memory.java b/src/java/org/apache/cassandra/io/util/Memory.java index 0ca6aa214a6c..eaa6e919e6d1 100644 --- a/src/java/org/apache/cassandra/io/util/Memory.java +++ b/src/java/org/apache/cassandra/io/util/Memory.java @@ -417,7 +417,7 @@ public ByteBuffer asByteBuffer(long offset, int length) public void setByteBuffer(ByteBuffer buffer, long offset, int length) { checkBounds(offset, offset + length); - MemoryUtil.setByteBuffer(buffer, peer + offset, length); + MemoryUtil.setDirectByteBuffer(buffer, peer + offset, length); } public String toString() diff --git a/src/java/org/apache/cassandra/io/util/SequentialWriter.java b/src/java/org/apache/cassandra/io/util/SequentialWriter.java index 3eb1a7d81500..9ad944be3bc0 100644 --- a/src/java/org/apache/cassandra/io/util/SequentialWriter.java +++ b/src/java/org/apache/cassandra/io/util/SequentialWriter.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.StandardOpenOption; @@ -43,6 +44,15 @@ public class SequentialWriter extends BufferedDataOutputStreamPlus implements Tr protected final FileChannel fchannel; + //Allow derived classes to specify writing to the channel + //directly shouldn't happen because they intercept via doFlush for things + //like compression or checksumming + //Another hack for this value is that it also indicates that 
flushing early + //should not occur, flushes aligned with buffer size are desired + //Unless... it's the last flush. Compression and checksum formats + //expect block (same as buffer size) alignment for everything except the last block + private final boolean strictFlushing; + // whether to do trickling fsync() to avoid sudden bursts of dirty buffer flushing by kernel causing read // latency spikes private final SequentialWriterOption option; @@ -388,6 +398,15 @@ public final void close() txnProxy.close(); } + public int writeDirectlyToChannel(ByteBuffer buf) throws IOException + { + if (strictFlushing) + throw new UnsupportedOperationException(); + // Don't allow writes to the underlying channel while data is buffered + flush(); + return channel.write(buf); + } + public final void finish() { txnProxy.finish(); diff --git a/src/java/org/apache/cassandra/io/util/UnbufferedDataOutputStreamPlus.java b/src/java/org/apache/cassandra/io/util/UnbufferedDataOutputStreamPlus.java index d9ef01064058..3d8321296cd1 100644 --- a/src/java/org/apache/cassandra/io/util/UnbufferedDataOutputStreamPlus.java +++ b/src/java/org/apache/cassandra/io/util/UnbufferedDataOutputStreamPlus.java @@ -371,15 +371,4 @@ public void write(ByteBuffer buf) throws IOException } } - public void write(Memory memory, long offset, long length) throws IOException - { - for (ByteBuffer buffer : memory.asByteBuffers(offset, length)) - write(buffer); - } - - @Override - public R applyToChannel(CheckedFunction f) throws IOException - { - return f.apply(channel); - } } diff --git a/src/java/org/apache/cassandra/locator/AlibabaCloudSnitch.java b/src/java/org/apache/cassandra/locator/AlibabaCloudSnitch.java new file mode 100644 index 000000000000..729e1b376393 --- /dev/null +++ b/src/java/org/apache/cassandra/locator/AlibabaCloudSnitch.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.locator; + +import java.io.DataInputStream; +import java.io.FilterInputStream; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.SocketTimeoutException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import org.apache.cassandra.db.SystemKeyspace; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.gms.ApplicationState; +import org.apache.cassandra.gms.EndpointState; +import org.apache.cassandra.gms.Gossiper; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.FBUtilities; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A snitch that assumes an ECS region is a DC and an ECS availability_zone + * is a rack. This information is available in the config for the node. the + * format of the zone-id is like :cn-hangzhou-a where cn means china, hangzhou + * means the hangzhou region, a means the az id. We use cn-hangzhou as the dc, + * and f as the zone-id. 
+ */ +public class AlibabaCloudSnitch extends AbstractNetworkTopologySnitch +{ + protected static final Logger logger = LoggerFactory.getLogger(AlibabaCloudSnitch.class); + protected static final String ZONE_NAME_QUERY_URL = "http://100.100.100.200/latest/meta-data/zone-id"; + private static final String DEFAULT_DC = "UNKNOWN-DC"; + private static final String DEFAULT_RACK = "UNKNOWN-RACK"; + private Map> savedEndpoints; + protected String ecsZone; + protected String ecsRegion; + + private static final int HTTP_CONNECT_TIMEOUT = 30000; + + + public AlibabaCloudSnitch() throws MalformedURLException, IOException + { + String response = alibabaApiCall(ZONE_NAME_QUERY_URL); + String[] splits = response.split("/"); + String az = splits[splits.length - 1]; + + // Split "us-central1-a" or "asia-east1-a" into "us-central1"/"a" and "asia-east1"/"a". + splits = az.split("-"); + ecsZone = splits[splits.length - 1]; + + int lastRegionIndex = az.lastIndexOf("-"); + ecsRegion = az.substring(0, lastRegionIndex); + + String datacenterSuffix = (new SnitchProperties()).get("dc_suffix", ""); + ecsRegion = ecsRegion.concat(datacenterSuffix); + logger.info("AlibabaSnitch using region: {}, zone: {}.", ecsRegion, ecsZone); + + } + + String alibabaApiCall(String url) throws ConfigurationException, IOException, SocketTimeoutException + { + // Populate the region and zone by introspection, fail if 404 on metadata + HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection(); + DataInputStream d = null; + try + { + conn.setConnectTimeout(HTTP_CONNECT_TIMEOUT); + conn.setRequestMethod("GET"); + + int code = conn.getResponseCode(); + if (code != HttpURLConnection.HTTP_OK) + throw new ConfigurationException("AlibabaSnitch was unable to execute the API call. Not an ecs node? and the returun code is " + code); + + // Read the information. I wish I could say (String) conn.getContent() here... 
+ int cl = conn.getContentLength(); + byte[] b = new byte[cl]; + d = new DataInputStream((FilterInputStream) conn.getContent()); + d.readFully(b); + return new String(b, StandardCharsets.UTF_8); + } + catch (SocketTimeoutException e) + { + throw new SocketTimeoutException("Timeout occurred reading a response from the Alibaba ECS metadata"); + } + finally + { + FileUtils.close(d); + conn.disconnect(); + } + } + + @Override + public String getRack(InetAddressAndPort endpoint) + { + if (endpoint.equals(FBUtilities.getBroadcastAddressAndPort())) + return ecsZone; + EndpointState state = Gossiper.instance.getEndpointStateForEndpoint(endpoint); + if (state == null || state.getApplicationState(ApplicationState.RACK) == null) + { + if (savedEndpoints == null) + savedEndpoints = SystemKeyspace.loadDcRackInfo(); + if (savedEndpoints.containsKey(endpoint)) + return savedEndpoints.get(endpoint).get("rack"); + return DEFAULT_RACK; + } + return state.getApplicationState(ApplicationState.RACK).value; + + } + + @Override + public String getDatacenter(InetAddressAndPort endpoint) + { + if (endpoint.equals(FBUtilities.getBroadcastAddressAndPort())) + return ecsRegion; + EndpointState state = Gossiper.instance.getEndpointStateForEndpoint(endpoint); + if (state == null || state.getApplicationState(ApplicationState.DC) == null) + { + if (savedEndpoints == null) + savedEndpoints = SystemKeyspace.loadDcRackInfo(); + if (savedEndpoints.containsKey(endpoint)) + return savedEndpoints.get(endpoint).get("data_center"); + return DEFAULT_DC; + } + return state.getApplicationState(ApplicationState.DC).value; + + } + +} diff --git a/src/java/org/apache/cassandra/locator/DynamicEndpointSnitch.java b/src/java/org/apache/cassandra/locator/DynamicEndpointSnitch.java index ddc8fba276f5..0b241ce0d519 100644 --- a/src/java/org/apache/cassandra/locator/DynamicEndpointSnitch.java +++ b/src/java/org/apache/cassandra/locator/DynamicEndpointSnitch.java @@ -35,6 +35,7 @@ import 
org.apache.cassandra.gms.EndpointState; import org.apache.cassandra.gms.Gossiper; import org.apache.cassandra.gms.VersionedValue; +import org.apache.cassandra.net.LatencySubscribers; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; @@ -43,7 +44,7 @@ /** * A dynamic snitch that sorts endpoints by latency with an adapted phi failure detector */ -public class DynamicEndpointSnitch extends AbstractEndpointSnitch implements ILatencySubscriber, DynamicEndpointSnitchMBean +public class DynamicEndpointSnitch extends AbstractEndpointSnitch implements LatencySubscribers.Subscriber, DynamicEndpointSnitchMBean { private static final boolean USE_SEVERITY = !Boolean.getBoolean("cassandra.ignore_dynamic_snitch_severity"); @@ -253,7 +254,7 @@ public int compareEndpoints(InetAddressAndPort target, Replica a1, Replica a2) throw new UnsupportedOperationException("You shouldn't wrap the DynamicEndpointSnitch (within itself or otherwise)"); } - public void receiveTiming(InetAddressAndPort host, long latency) // this is cheap + public void receiveTiming(InetAddressAndPort host, long latency, TimeUnit unit) // this is cheap { ExponentiallyDecayingReservoir sample = samples.get(host); if (sample == null) @@ -263,7 +264,7 @@ public void receiveTiming(InetAddressAndPort host, long latency) // this is chea if (sample == null) sample = maybeNewSample; } - sample.update(latency); + sample.update(unit.toMillis(latency)); } private void updateScores() // this is expensive @@ -274,7 +275,7 @@ private void updateScores() // this is expensive { if (MessagingService.instance() != null) { - MessagingService.instance().register(this); + MessagingService.instance().latencySubscribers.subscribe(this); registered = true; } diff --git a/src/java/org/apache/cassandra/locator/Endpoints.java b/src/java/org/apache/cassandra/locator/Endpoints.java index a2bad6ce6fb2..c1a928214487 100644 --- 
a/src/java/org/apache/cassandra/locator/Endpoints.java +++ b/src/java/org/apache/cassandra/locator/Endpoints.java @@ -21,11 +21,15 @@ import org.apache.cassandra.locator.ReplicaCollection.Builder.Conflict; import org.apache.cassandra.utils.FBUtilities; +import java.util.AbstractList; import java.util.Collection; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import com.google.common.collect.Lists; + /** * A collection of Endpoints for a given ring position. This will typically reside in a ReplicaLayout, * representing some subset of the endpoints for the Token or Range @@ -52,6 +56,22 @@ public Set endpoints() return byEndpoint().keySet(); } + public List endpointList() + { + return new AbstractList() + { + public InetAddressAndPort get(int index) + { + return list.get(index).endpoint(); + } + + public int size() + { + return list.size; + } + }; + } + public Map byEndpoint() { ReplicaMap map = byEndpoint; diff --git a/src/java/org/apache/cassandra/locator/InetAddressAndPort.java b/src/java/org/apache/cassandra/locator/InetAddressAndPort.java index a47c72a71696..6821f139867b 100644 --- a/src/java/org/apache/cassandra/locator/InetAddressAndPort.java +++ b/src/java/org/apache/cassandra/locator/InetAddressAndPort.java @@ -15,16 +15,24 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.cassandra.locator; +import java.io.IOException; import java.io.Serializable; +import java.net.Inet4Address; +import java.net.Inet6Address; import java.net.InetAddress; import java.net.UnknownHostException; +import java.nio.ByteBuffer; import com.google.common.base.Preconditions; import com.google.common.net.HostAndPort; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.FastByteOperations; @@ -41,6 +49,7 @@ * need to sometimes return a port and sometimes not. * */ +@SuppressWarnings("UnstableApiUsage") public final class InetAddressAndPort implements Comparable, Serializable { private static final long serialVersionUID = 0; @@ -65,6 +74,11 @@ private InetAddressAndPort(InetAddress address, byte[] addressBytes, int port) this.addressBytes = addressBytes; } + public InetAddressAndPort withPort(int port) + { + return new InetAddressAndPort(address, addressBytes, port); + } + private static void validatePortRange(int port) { if (port < 0 | port > 65535) @@ -127,7 +141,7 @@ public String toString(boolean withPort) { if (withPort) { - return HostAndPort.fromParts(address.getHostAddress(), port).toString(); + return toString(address, port); } else { @@ -135,6 +149,11 @@ public String toString(boolean withPort) } } + public static String toString(InetAddress address, int port) + { + return HostAndPort.fromParts(address.getHostAddress(), port).toString(); + } + public static InetAddressAndPort getByName(String name) throws UnknownHostException { return getByNameOverrideDefaults(name, null); @@ -144,8 +163,6 @@ public static InetAddressAndPort getByName(String name) throws UnknownHostExcept * * @param name Hostname + optional ports string * @param port Port to 
connect on, overridden by values in hostname string, defaults to DatabaseDescriptor default if not specified anywhere. - * @return - * @throws UnknownHostException */ public static InetAddressAndPort getByNameOverrideDefaults(String name, Integer port) throws UnknownHostException { @@ -201,4 +218,114 @@ public static void initializeDefaultPort(int port) { defaultPort = port; } + + static int getDefaultPort() + { + return defaultPort; + } + + /* + * As of version 4.0 the endpoint description includes a port number as an unsigned short + */ + public static final class Serializer implements IVersionedSerializer + { + public static final int MAXIMUM_SIZE = 19; + + // We put the static instance here, to avoid complexity with dtests. + // InetAddressAndPort is one of the only classes we share between instances, which is possible cleanly + // because it has no type-dependencies in its public API, however Serializer requires DataOutputPlus, which requires... + // and the chain becomes quite unwieldy + public static final Serializer inetAddressAndPortSerializer = new Serializer(); + + private Serializer() {} + + public void serialize(InetAddressAndPort endpoint, DataOutputPlus out, int version) throws IOException + { + byte[] buf = endpoint.addressBytes; + + if (version >= MessagingService.VERSION_40) + { + out.writeByte(buf.length + 2); + out.write(buf); + out.writeShort(endpoint.port); + } + else + { + out.writeByte(buf.length); + out.write(buf); + } + } + + public InetAddressAndPort deserialize(DataInputPlus in, int version) throws IOException + { + int size = in.readByte() & 0xFF; + switch(size) + { + //The original pre-4.0 serialiation of just an address + case 4: + case 16: + { + byte[] bytes = new byte[size]; + in.readFully(bytes, 0, bytes.length); + return getByAddress(bytes); + } + //Address and one port + case 6: + case 18: + { + byte[] bytes = new byte[size - 2]; + in.readFully(bytes); + + int port = in.readShort() & 0xFFFF; + return 
getByAddressOverrideDefaults(InetAddress.getByAddress(bytes), bytes, port); + } + default: + throw new AssertionError("Unexpected size " + size); + + } + } + + /** + * Extract {@link InetAddressAndPort} from the provided {@link ByteBuffer} without altering its state. + */ + public InetAddressAndPort extract(ByteBuffer buf, int position) throws IOException + { + int size = buf.get(position++) & 0xFF; + if (size == 4 || size == 16) + { + byte[] bytes = new byte[size]; + ByteBufferUtil.copyBytes(buf, position, bytes, 0, size); + return getByAddress(bytes); + } + else if (size == 6 || size == 18) + { + byte[] bytes = new byte[size - 2]; + ByteBufferUtil.copyBytes(buf, position, bytes, 0, size - 2); + position += (size - 2); + int port = buf.getShort(position) & 0xFFFF; + return getByAddressOverrideDefaults(InetAddress.getByAddress(bytes), bytes, port); + } + + throw new AssertionError("Unexpected pre-4.0 InetAddressAndPort size " + size); + } + + public long serializedSize(InetAddressAndPort from, int version) + { + //4.0 includes a port number + if (version >= MessagingService.VERSION_40) + { + if (from.address instanceof Inet4Address) + return 1 + 4 + 2; + assert from.address instanceof Inet6Address; + return 1 + 16 + 2; + } + else + { + if (from.address instanceof Inet4Address) + return 1 + 4; + assert from.address instanceof Inet6Address; + return 1 + 16; + } + } + } } diff --git a/src/java/org/apache/cassandra/locator/ReconnectableSnitchHelper.java b/src/java/org/apache/cassandra/locator/ReconnectableSnitchHelper.java index 547901086d9a..dea8c76f4e4e 100644 --- a/src/java/org/apache/cassandra/locator/ReconnectableSnitchHelper.java +++ b/src/java/org/apache/cassandra/locator/ReconnectableSnitchHelper.java @@ -22,13 +22,16 @@ import com.google.common.annotations.VisibleForTesting; -import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.gms.*; +import org.apache.cassandra.net.ConnectionCategory; import 
org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.OutboundConnectionSettings; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; + /** * Sidekick helper for snitches that want to reconnect from one IP addr for a node to another. * Typically, this is for situations like EC2 where a node will have a public address and a private address, @@ -63,16 +66,15 @@ private void reconnect(InetAddressAndPort publicAddress, VersionedValue localAdd @VisibleForTesting static void reconnect(InetAddressAndPort publicAddress, InetAddressAndPort localAddress, IEndpointSnitch snitch, String localDc) { - if (!DatabaseDescriptor.getInternodeAuthenticator().authenticate(publicAddress.address, MessagingService.instance().portFor(publicAddress))) + if (!new OutboundConnectionSettings(publicAddress, localAddress).withDefaults(ConnectionCategory.MESSAGING).authenticate()) { logger.debug("InternodeAuthenticator said don't reconnect to {} on {}", publicAddress, localAddress); return; } - if (snitch.getDatacenter(publicAddress).equals(localDc) - && !MessagingService.instance().getCurrentEndpoint(publicAddress).equals(localAddress)) + if (snitch.getDatacenter(publicAddress).equals(localDc)) { - MessagingService.instance().reconnectWithNewIp(publicAddress, localAddress); + MessagingService.instance().maybeReconnectWithNewIp(publicAddress, localAddress); logger.debug("Initiated reconnect to an Internal IP {} for the {}", localAddress, publicAddress); } } diff --git a/src/java/org/apache/cassandra/metrics/ClientMetrics.java b/src/java/org/apache/cassandra/metrics/ClientMetrics.java index a80033ab0e73..7599096ae954 100644 --- a/src/java/org/apache/cassandra/metrics/ClientMetrics.java +++ b/src/java/org/apache/cassandra/metrics/ClientMetrics.java @@ -19,6 +19,7 @@ package org.apache.cassandra.metrics; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import 
com.codahale.metrics.Gauge; import com.codahale.metrics.Meter; @@ -40,6 +41,10 @@ public final class ClientMetrics private Meter authSuccess; private Meter authFailure; + private AtomicInteger pausedConnections; + private Gauge pausedConnectionsGauge; + private Meter requestDiscarded; + private ClientMetrics() { } @@ -54,6 +59,11 @@ public void markAuthFailure() authFailure.mark(); } + public void pauseConnection() { pausedConnections.incrementAndGet(); } + public void unpauseConnection() { pausedConnections.decrementAndGet(); } + + public void markRequestDiscarded() { requestDiscarded.mark(); } + public List allConnectedClients() { List clients = new ArrayList<>(); @@ -79,6 +89,10 @@ public synchronized void init(Collection servers) authSuccess = registerMeter("AuthSuccess"); authFailure = registerMeter("AuthFailure"); + pausedConnections = new AtomicInteger(); + pausedConnectionsGauge = registerGauge("PausedConnections", pausedConnections::get); + requestDiscarded = registerMeter("RequestDiscarded"); + initialized = true; } diff --git a/src/java/org/apache/cassandra/metrics/CommitLogMetrics.java b/src/java/org/apache/cassandra/metrics/CommitLogMetrics.java index 08c1c8e46d28..4473760d0e5f 100644 --- a/src/java/org/apache/cassandra/metrics/CommitLogMetrics.java +++ b/src/java/org/apache/cassandra/metrics/CommitLogMetrics.java @@ -20,7 +20,7 @@ import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; import org.apache.cassandra.db.commitlog.AbstractCommitLogService; -import org.apache.cassandra.db.commitlog.AbstractCommitLogSegmentManager; +import org.apache.cassandra.db.commitlog.CommitLogSegmentManager; import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics; @@ -48,7 +48,7 @@ public CommitLogMetrics() waitingOnCommit = Metrics.timer(factory.createMetricName("WaitingOnCommit")); } - public void attach(final AbstractCommitLogService service, final AbstractCommitLogSegmentManager segmentManager) + public void attach(final 
AbstractCommitLogService service, final CommitLogSegmentManager segmentManager) { completedTasks = Metrics.register(factory.createMetricName("CompletedTasks"), new Gauge() { diff --git a/src/java/org/apache/cassandra/metrics/ConnectionMetrics.java b/src/java/org/apache/cassandra/metrics/ConnectionMetrics.java deleted file mode 100644 index 3655a404d080..000000000000 --- a/src/java/org/apache/cassandra/metrics/ConnectionMetrics.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.metrics; - -import com.codahale.metrics.Gauge; -import com.codahale.metrics.Meter; -import org.apache.cassandra.net.async.OutboundMessagingPool; - -import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics; - -import org.apache.cassandra.locator.InetAddressAndPort; - -/** - * Metrics for internode connections. 
- */ -public class ConnectionMetrics -{ - public static final String TYPE_NAME = "Connection"; - - /** Total number of timeouts happened on this node */ - public static final Meter totalTimeouts = Metrics.meter(DefaultNameFactory.createMetricName(TYPE_NAME, "TotalTimeouts", null)); - - public final String address; - /** Pending tasks for large message TCP Connections */ - public final Gauge largeMessagePendingTasks; - /** Completed tasks for large message TCP Connections */ - public final Gauge largeMessageCompletedTasks; - /** Dropped tasks for large message TCP Connections */ - public final Gauge largeMessageDroppedTasks; - /** Pending tasks for small message TCP Connections */ - public final Gauge smallMessagePendingTasks; - /** Completed tasks for small message TCP Connections */ - public final Gauge smallMessageCompletedTasks; - /** Dropped tasks for small message TCP Connections */ - public final Gauge smallMessageDroppedTasks; - /** Pending tasks for gossip message TCP Connections */ - public final Gauge gossipMessagePendingTasks; - /** Completed tasks for gossip message TCP Connections */ - public final Gauge gossipMessageCompletedTasks; - /** Dropped tasks for gossip message TCP Connections */ - public final Gauge gossipMessageDroppedTasks; - - /** Number of timeouts for specific IP */ - public final Meter timeouts; - - private final MetricNameFactory factory; - - /** - * Create metrics for given connection pool. 
- * - * @param ip IP address to use for metrics label - */ - public ConnectionMetrics(InetAddressAndPort ip, final OutboundMessagingPool messagingPool) - { - // ipv6 addresses will contain colons, which are invalid in a JMX ObjectName - address = ip.toString().replace(':', '.'); - - factory = new DefaultNameFactory("Connection", address); - - largeMessagePendingTasks = Metrics.register(factory.createMetricName("LargeMessagePendingTasks"), new Gauge() - { - public Integer getValue() - { - return messagingPool.largeMessageChannel.getPendingMessages(); - } - }); - largeMessageCompletedTasks = Metrics.register(factory.createMetricName("LargeMessageCompletedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.largeMessageChannel.getCompletedMessages(); - } - }); - largeMessageDroppedTasks = Metrics.register(factory.createMetricName("LargeMessageDroppedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.largeMessageChannel.getDroppedMessages(); - } - }); - smallMessagePendingTasks = Metrics.register(factory.createMetricName("SmallMessagePendingTasks"), new Gauge() - { - public Integer getValue() - { - return messagingPool.smallMessageChannel.getPendingMessages(); - } - }); - smallMessageCompletedTasks = Metrics.register(factory.createMetricName("SmallMessageCompletedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.smallMessageChannel.getCompletedMessages(); - } - }); - smallMessageDroppedTasks = Metrics.register(factory.createMetricName("SmallMessageDroppedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.smallMessageChannel.getDroppedMessages(); - } - }); - gossipMessagePendingTasks = Metrics.register(factory.createMetricName("GossipMessagePendingTasks"), new Gauge() - { - public Integer getValue() - { - return messagingPool.gossipChannel.getPendingMessages(); - } - }); - gossipMessageCompletedTasks = 
Metrics.register(factory.createMetricName("GossipMessageCompletedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.gossipChannel.getCompletedMessages(); - } - }); - gossipMessageDroppedTasks = Metrics.register(factory.createMetricName("GossipMessageDroppedTasks"), new Gauge() - { - public Long getValue() - { - return messagingPool.gossipChannel.getDroppedMessages(); - } - }); - timeouts = Metrics.meter(factory.createMetricName("Timeouts")); - } - - public void release() - { - Metrics.remove(factory.createMetricName("LargeMessagePendingTasks")); - Metrics.remove(factory.createMetricName("LargeMessageCompletedTasks")); - Metrics.remove(factory.createMetricName("LargeMessageDroppedTasks")); - Metrics.remove(factory.createMetricName("SmallMessagePendingTasks")); - Metrics.remove(factory.createMetricName("SmallMessageCompletedTasks")); - Metrics.remove(factory.createMetricName("SmallMessageDroppedTasks")); - Metrics.remove(factory.createMetricName("GossipMessagePendingTasks")); - Metrics.remove(factory.createMetricName("GossipMessageCompletedTasks")); - Metrics.remove(factory.createMetricName("GossipMessageDroppedTasks")); - Metrics.remove(factory.createMetricName("Timeouts")); - } -} diff --git a/src/java/org/apache/cassandra/metrics/DroppedMessageMetrics.java b/src/java/org/apache/cassandra/metrics/DroppedMessageMetrics.java index 794fa9cb8e77..8c227783f559 100644 --- a/src/java/org/apache/cassandra/metrics/DroppedMessageMetrics.java +++ b/src/java/org/apache/cassandra/metrics/DroppedMessageMetrics.java @@ -21,6 +21,7 @@ import com.codahale.metrics.Timer; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics; @@ -38,7 +39,7 @@ public class DroppedMessageMetrics /** The cross node dropped latency */ public final Timer crossNodeDroppedLatency; - public DroppedMessageMetrics(MessagingService.Verb verb) + public 
DroppedMessageMetrics(Verb verb) { this(new DefaultNameFactory("DroppedMessage", verb.toString())); } diff --git a/src/java/org/apache/cassandra/metrics/FrequencySampler.java b/src/java/org/apache/cassandra/metrics/FrequencySampler.java index c09434714d7a..8a8918b9fa57 100644 --- a/src/java/org/apache/cassandra/metrics/FrequencySampler.java +++ b/src/java/org/apache/cassandra/metrics/FrequencySampler.java @@ -26,6 +26,8 @@ import com.clearspring.analytics.stream.StreamSummary; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + /** * Find the most frequent sample. A sample adds to the sum of its key ie *

add("x", 10); and add("x", 20); will result in "x" = 30

This uses StreamSummary to only store the @@ -37,7 +39,7 @@ public abstract class FrequencySampler extends Sampler { private static final Logger logger = LoggerFactory.getLogger(FrequencySampler.class); - private long endTimeMillis = -1; + private long endTimeNanos = -1; private StreamSummary summary; @@ -51,10 +53,10 @@ public abstract class FrequencySampler extends Sampler */ public synchronized void beginSampling(int capacity, int durationMillis) { - if (endTimeMillis == -1 || clock.currentTimeMillis() > endTimeMillis) + if (endTimeNanos == -1 || clock.now() > endTimeNanos) { - summary = new StreamSummary(capacity); - endTimeMillis = clock.currentTimeMillis() + durationMillis; + summary = new StreamSummary<>(capacity); + endTimeNanos = clock.now() + MILLISECONDS.toNanos(durationMillis); } else throw new RuntimeException("Sampling already in progress"); @@ -67,9 +69,9 @@ public synchronized void beginSampling(int capacity, int durationMillis) public synchronized List> finishSampling(int count) { List> results = Collections.emptyList(); - if (endTimeMillis != -1) + if (endTimeNanos != -1) { - endTimeMillis = -1; + endTimeNanos = -1; results = summary.topK(count) .stream() .map(c -> new Sample(c.getItem(), c.getCount(), c.getError())) @@ -82,7 +84,7 @@ protected synchronized void insert(final T item, final long value) { // samplerExecutor is single threaded but still need // synchronization against jmx calls to finishSampling - if (value > 0 && clock.currentTimeMillis() <= endTimeMillis) + if (value > 0 && clock.now() <= endTimeNanos) { try { @@ -96,7 +98,7 @@ protected synchronized void insert(final T item, final long value) public boolean isEnabled() { - return endTimeMillis != -1 && clock.currentTimeMillis() <= endTimeMillis; + return endTimeNanos != -1 && clock.now() <= endTimeNanos; } } diff --git a/src/java/org/apache/cassandra/metrics/InternodeInboundMetrics.java b/src/java/org/apache/cassandra/metrics/InternodeInboundMetrics.java new file mode 100644 index 
000000000000..cc3c1c0d2d73 --- /dev/null +++ b/src/java/org/apache/cassandra/metrics/InternodeInboundMetrics.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.metrics; + +import com.codahale.metrics.Gauge; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.InboundMessageHandlers; +import org.apache.cassandra.metrics.CassandraMetricsRegistry.MetricName; + +/** + * Metrics for internode connections. + */ +public class InternodeInboundMetrics +{ + private final MetricName corruptFramesRecovered; + private final MetricName corruptFramesUnrecovered; + private final MetricName errorBytes; + private final MetricName errorCount; + private final MetricName expiredBytes; + private final MetricName expiredCount; + private final MetricName pendingBytes; + private final MetricName pendingCount; + private final MetricName processedBytes; + private final MetricName processedCount; + private final MetricName receivedBytes; + private final MetricName receivedCount; + private final MetricName throttledCount; + private final MetricName throttledNanos; + + /** + * Create metrics for given inbound message handlers. 
+ * + * @param peer IP address and port to use for metrics label + */ + public InternodeInboundMetrics(InetAddressAndPort peer, InboundMessageHandlers handlers) + { + // ipv6 addresses will contain colons, which are invalid in a JMX ObjectName + MetricNameFactory factory = new DefaultNameFactory("InboundConnection", peer.toString().replace(':', '_')); + + register(corruptFramesRecovered = factory.createMetricName("CorruptFramesRecovered"), handlers::corruptFramesRecovered); + register(corruptFramesUnrecovered = factory.createMetricName("CorruptFramesUnrecovered"), handlers::corruptFramesUnrecovered); + register(errorBytes = factory.createMetricName("ErrorBytes"), handlers::errorBytes); + register(errorCount = factory.createMetricName("ErrorCount"), handlers::errorCount); + register(expiredBytes = factory.createMetricName("ExpiredBytes"), handlers::expiredBytes); + register(expiredCount = factory.createMetricName("ExpiredCount"), handlers::expiredCount); + register(pendingBytes = factory.createMetricName("ScheduledBytes"), handlers::scheduledBytes); + register(pendingCount = factory.createMetricName("ScheduledCount"), handlers::scheduledCount); + register(processedBytes = factory.createMetricName("ProcessedBytes"), handlers::processedBytes); + register(processedCount = factory.createMetricName("ProcessedCount"), handlers::processedCount); + register(receivedBytes = factory.createMetricName("ReceivedBytes"), handlers::receivedBytes); + register(receivedCount = factory.createMetricName("ReceivedCount"), handlers::receivedCount); + register(throttledCount = factory.createMetricName("ThrottledCount"), handlers::throttledCount); + register(throttledNanos = factory.createMetricName("ThrottledNanos"), handlers::throttledNanos); + } + + public void release() + { + remove(corruptFramesRecovered); + remove(corruptFramesUnrecovered); + remove(errorBytes); + remove(errorCount); + remove(expiredBytes); + remove(expiredCount); + remove(pendingBytes); + remove(pendingCount); + 
remove(processedBytes); + remove(processedCount); + remove(receivedBytes); + remove(receivedCount); + remove(throttledCount); + remove(throttledNanos); + } + + private static void register(MetricName name, Gauge gauge) + { + CassandraMetricsRegistry.Metrics.register(name, gauge); + } + + private static void remove(MetricName name) + { + CassandraMetricsRegistry.Metrics.remove(name); + } +} diff --git a/src/java/org/apache/cassandra/metrics/InternodeOutboundMetrics.java b/src/java/org/apache/cassandra/metrics/InternodeOutboundMetrics.java new file mode 100644 index 000000000000..f04b42877ba4 --- /dev/null +++ b/src/java/org/apache/cassandra/metrics/InternodeOutboundMetrics.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.metrics; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Meter; +import org.apache.cassandra.net.OutboundConnections; + +import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics; + +import org.apache.cassandra.locator.InetAddressAndPort; + +/** + * Metrics for internode connections. 
+ */ +public class InternodeOutboundMetrics +{ + public static final String TYPE_NAME = "Connection"; + + /** Total number of callbacks that were not completed successfully for messages that were sent to this node + * TODO this was always broken, as it never counted those messages without callbacks? So perhaps we can redefine it. */ + public static final Meter totalExpiredCallbacks = Metrics.meter(DefaultNameFactory.createMetricName(TYPE_NAME, "TotalTimeouts", null)); + + /** Number of timeouts for specific IP */ + public final Meter expiredCallbacks; + + public final String address; + /** Pending tasks for large message TCP Connections */ + public final Gauge largeMessagePendingTasks; + /** Pending bytes for large message TCP Connections */ + public final Gauge largeMessagePendingBytes; + /** Completed tasks for large message TCP Connections */ + public final Gauge largeMessageCompletedTasks; + /** Completed bytes for large message TCP Connections */ + public final Gauge largeMessageCompletedBytes; + /** Dropped tasks for large message TCP Connections */ + public final Gauge largeMessageDropped; + /** Dropped tasks because of timeout for large message TCP Connections */ + public final Gauge largeMessageDroppedTasksDueToTimeout; + /** Dropped bytes because of timeout for large message TCP Connections */ + public final Gauge largeMessageDroppedBytesDueToTimeout; + /** Dropped tasks because of overload for large message TCP Connections */ + public final Gauge largeMessageDroppedTasksDueToOverload; + /** Dropped bytes because of overload for large message TCP Connections */ + public final Gauge largeMessageDroppedBytesDueToOverload; + /** Dropped tasks because of error for large message TCP Connections */ + public final Gauge largeMessageDroppedTasksDueToError; + /** Dropped bytes because of error for large message TCP Connections */ + public final Gauge largeMessageDroppedBytesDueToError; + /** Pending tasks for small message TCP Connections */ + public final Gauge 
smallMessagePendingTasks; + /** Pending bytes for small message TCP Connections */ + public final Gauge smallMessagePendingBytes; + /** Completed tasks for small message TCP Connections */ + public final Gauge smallMessageCompletedTasks; + /** Completed bytes for small message TCP Connections */ + public final Gauge smallMessageCompletedBytes; + /** Dropped tasks for small message TCP Connections */ + public final Gauge smallMessageDroppedTasks; + /** Dropped tasks because of timeout for small message TCP Connections */ + public final Gauge smallMessageDroppedTasksDueToTimeout; + /** Dropped bytes because of timeout for small message TCP Connections */ + public final Gauge smallMessageDroppedBytesDueToTimeout; + /** Dropped tasks because of overload for small message TCP Connections */ + public final Gauge smallMessageDroppedTasksDueToOverload; + /** Dropped bytes because of overload for small message TCP Connections */ + public final Gauge smallMessageDroppedBytesDueToOverload; + /** Dropped tasks because of error for small message TCP Connections */ + public final Gauge smallMessageDroppedTasksDueToError; + /** Dropped bytes because of error for small message TCP Connections */ + public final Gauge smallMessageDroppedBytesDueToError; + /** Pending tasks for small message TCP Connections */ + public final Gauge urgentMessagePendingTasks; + /** Pending bytes for urgent message TCP Connections */ + public final Gauge urgentMessagePendingBytes; + /** Completed tasks for urgent message TCP Connections */ + public final Gauge urgentMessageCompletedTasks; + /** Completed bytes for urgent message TCP Connections */ + public final Gauge urgentMessageCompletedBytes; + /** Dropped tasks for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedTasks; + /** Dropped tasks because of timeout for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedTasksDueToTimeout; + /** Dropped bytes because of timeout for urgent message TCP 
Connections */ + public final Gauge urgentMessageDroppedBytesDueToTimeout; + /** Dropped tasks because of overload for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedTasksDueToOverload; + /** Dropped bytes because of overload for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedBytesDueToOverload; + /** Dropped tasks because of error for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedTasksDueToError; + /** Dropped bytes because of error for urgent message TCP Connections */ + public final Gauge urgentMessageDroppedBytesDueToError; + + private final MetricNameFactory factory; + + /** + * Create metrics for given connection pool. + * + * @param ip IP address to use for metrics label + */ + public InternodeOutboundMetrics(InetAddressAndPort ip, final OutboundConnections messagingPool) + { + // ipv6 addresses will contain colons, which are invalid in a JMX ObjectName + address = ip.toString().replace(':', '_'); + + factory = new DefaultNameFactory("Connection", address); + + largeMessagePendingTasks = Metrics.register(factory.createMetricName("LargeMessagePendingTasks"), messagingPool.large::pendingCount); + largeMessagePendingBytes = Metrics.register(factory.createMetricName("LargeMessagePendingBytes"), messagingPool.large::pendingBytes); + largeMessageCompletedTasks = Metrics.register(factory.createMetricName("LargeMessageCompletedTasks"),messagingPool.large::sentCount); + largeMessageCompletedBytes = Metrics.register(factory.createMetricName("LargeMessageCompletedBytes"),messagingPool.large::sentBytes); + largeMessageDropped = Metrics.register(factory.createMetricName("LargeMessageDroppedTasks"), messagingPool.large::dropped); + largeMessageDroppedTasksDueToOverload = Metrics.register(factory.createMetricName("LargeMessageDroppedTasksDueToOverload"), messagingPool.large::overloadedCount); + largeMessageDroppedBytesDueToOverload = 
Metrics.register(factory.createMetricName("LargeMessageDroppedBytesDueToOverload"), messagingPool.large::overloadedBytes); + largeMessageDroppedTasksDueToTimeout = Metrics.register(factory.createMetricName("LargeMessageDroppedTasksDueToTimeout"), messagingPool.large::expiredCount); + largeMessageDroppedBytesDueToTimeout = Metrics.register(factory.createMetricName("LargeMessageDroppedBytesDueToTimeout"), messagingPool.large::expiredBytes); + largeMessageDroppedTasksDueToError = Metrics.register(factory.createMetricName("LargeMessageDroppedTasksDueToError"), messagingPool.large::errorCount); + largeMessageDroppedBytesDueToError = Metrics.register(factory.createMetricName("LargeMessageDroppedBytesDueToError"), messagingPool.large::errorBytes); + smallMessagePendingTasks = Metrics.register(factory.createMetricName("SmallMessagePendingTasks"), messagingPool.small::pendingCount); + smallMessagePendingBytes = Metrics.register(factory.createMetricName("SmallMessagePendingBytes"), messagingPool.small::pendingBytes); + smallMessageCompletedTasks = Metrics.register(factory.createMetricName("SmallMessageCompletedTasks"), messagingPool.small::sentCount); + smallMessageCompletedBytes = Metrics.register(factory.createMetricName("SmallMessageCompletedBytes"),messagingPool.small::sentBytes); + smallMessageDroppedTasks = Metrics.register(factory.createMetricName("SmallMessageDroppedTasks"), messagingPool.small::dropped); + smallMessageDroppedTasksDueToOverload = Metrics.register(factory.createMetricName("SmallMessageDroppedTasksDueToOverload"), messagingPool.small::overloadedCount); + smallMessageDroppedBytesDueToOverload = Metrics.register(factory.createMetricName("SmallMessageDroppedBytesDueToOverload"), messagingPool.small::overloadedBytes); + smallMessageDroppedTasksDueToTimeout = Metrics.register(factory.createMetricName("SmallMessageDroppedTasksDueToTimeout"), messagingPool.small::expiredCount); + smallMessageDroppedBytesDueToTimeout = 
Metrics.register(factory.createMetricName("SmallMessageDroppedBytesDueToTimeout"), messagingPool.small::expiredBytes); + smallMessageDroppedTasksDueToError = Metrics.register(factory.createMetricName("SmallMessageDroppedTasksDueToError"), messagingPool.small::errorCount); + smallMessageDroppedBytesDueToError = Metrics.register(factory.createMetricName("SmallMessageDroppedBytesDueToError"), messagingPool.small::errorBytes); + urgentMessagePendingTasks = Metrics.register(factory.createMetricName("UrgentMessagePendingTasks"), messagingPool.urgent::pendingCount); + urgentMessagePendingBytes = Metrics.register(factory.createMetricName("UrgentMessagePendingBytes"), messagingPool.urgent::pendingBytes); + urgentMessageCompletedTasks = Metrics.register(factory.createMetricName("UrgentMessageCompletedTasks"), messagingPool.urgent::sentCount); + urgentMessageCompletedBytes = Metrics.register(factory.createMetricName("UrgentMessageCompletedBytes"),messagingPool.urgent::sentBytes); + urgentMessageDroppedTasks = Metrics.register(factory.createMetricName("UrgentMessageDroppedTasks"), messagingPool.urgent::dropped); + urgentMessageDroppedTasksDueToOverload = Metrics.register(factory.createMetricName("UrgentMessageDroppedTasksDueToOverload"), messagingPool.urgent::overloadedCount); + urgentMessageDroppedBytesDueToOverload = Metrics.register(factory.createMetricName("UrgentMessageDroppedBytesDueToOverload"), messagingPool.urgent::overloadedBytes); + urgentMessageDroppedTasksDueToTimeout = Metrics.register(factory.createMetricName("UrgentMessageDroppedTasksDueToTimeout"), messagingPool.urgent::expiredCount); + urgentMessageDroppedBytesDueToTimeout = Metrics.register(factory.createMetricName("UrgentMessageDroppedBytesDueToTimeout"), messagingPool.urgent::expiredBytes); + urgentMessageDroppedTasksDueToError = Metrics.register(factory.createMetricName("UrgentMessageDroppedTasksDueToError"), messagingPool.urgent::errorCount); + urgentMessageDroppedBytesDueToError = 
Metrics.register(factory.createMetricName("UrgentMessageDroppedBytesDueToError"), messagingPool.urgent::errorBytes); + expiredCallbacks = Metrics.meter(factory.createMetricName("Timeouts")); + + // deprecated + Metrics.register(factory.createMetricName("GossipMessagePendingTasks"), (Gauge) messagingPool.urgent::pendingCount); + Metrics.register(factory.createMetricName("GossipMessageCompletedTasks"), (Gauge) messagingPool.urgent::sentCount); + Metrics.register(factory.createMetricName("GossipMessageDroppedTasks"), (Gauge) messagingPool.urgent::dropped); + } + + public void release() + { + Metrics.remove(factory.createMetricName("LargeMessagePendingTasks")); + Metrics.remove(factory.createMetricName("LargeMessagePendingBytes")); + Metrics.remove(factory.createMetricName("LargeMessageCompletedTasks")); + Metrics.remove(factory.createMetricName("LargeMessageCompletedBytes")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedTasks")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedTasksDueToTimeout")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedBytesDueToTimeout")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedTasksDueToOverload")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedBytesDueToOverload")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedTasksDueToError")); + Metrics.remove(factory.createMetricName("LargeMessageDroppedBytesDueToError")); + Metrics.remove(factory.createMetricName("SmallMessagePendingTasks")); + Metrics.remove(factory.createMetricName("SmallMessagePendingBytes")); + Metrics.remove(factory.createMetricName("SmallMessageCompletedTasks")); + Metrics.remove(factory.createMetricName("SmallMessageCompletedBytes")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedTasks")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedTasksDueToTimeout")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedBytesDueToTimeout")); + 
Metrics.remove(factory.createMetricName("SmallMessageDroppedTasksDueToOverload")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedBytesDueToOverload")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedTasksDueToError")); + Metrics.remove(factory.createMetricName("SmallMessageDroppedBytesDueToError")); + Metrics.remove(factory.createMetricName("GossipMessagePendingTasks")); + Metrics.remove(factory.createMetricName("GossipMessageCompletedTasks")); + Metrics.remove(factory.createMetricName("GossipMessageDroppedTasks")); + Metrics.remove(factory.createMetricName("UrgentMessagePendingTasks")); + Metrics.remove(factory.createMetricName("UrgentMessagePendingBytes")); + Metrics.remove(factory.createMetricName("UrgentMessageCompletedTasks")); + Metrics.remove(factory.createMetricName("UrgentMessageCompletedBytes")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedTasks")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedTasksDueToTimeout")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedBytesDueToTimeout")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedTasksDueToOverload")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedBytesDueToOverload")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedTasksDueToError")); + Metrics.remove(factory.createMetricName("UrgentMessageDroppedBytesDueToError")); + Metrics.remove(factory.createMetricName("Timeouts")); + } +} diff --git a/src/java/org/apache/cassandra/metrics/MaxSampler.java b/src/java/org/apache/cassandra/metrics/MaxSampler.java index f4fb87351207..df24bb96298b 100644 --- a/src/java/org/apache/cassandra/metrics/MaxSampler.java +++ b/src/java/org/apache/cassandra/metrics/MaxSampler.java @@ -24,23 +24,25 @@ import com.google.common.collect.MinMaxPriorityQueue; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public abstract class MaxSampler extends Sampler { private int capacity; private 
MinMaxPriorityQueue> queue; - private long endTimeMillis = -1; + private long endTimeNanos = -1; private final Comparator> comp = Collections.reverseOrder(Comparator.comparing(p -> p.count)); public boolean isEnabled() { - return endTimeMillis != -1 && clock.currentTimeMillis() <= endTimeMillis; + return endTimeNanos != -1 && clock.now() <= endTimeNanos; } public synchronized void beginSampling(int capacity, int durationMillis) { - if (endTimeMillis == -1 || clock.currentTimeMillis() > endTimeMillis) + if (endTimeNanos == -1 || clock.now() > endTimeNanos) { - endTimeMillis = clock.currentTimeMillis() + durationMillis; + endTimeNanos = clock.now() + MILLISECONDS.toNanos(durationMillis); queue = MinMaxPriorityQueue .orderedBy(comp) .maximumSize(Math.max(1, capacity)) @@ -54,9 +56,9 @@ public synchronized void beginSampling(int capacity, int durationMillis) public synchronized List> finishSampling(int count) { List> result = new ArrayList<>(count); - if (endTimeMillis != -1) + if (endTimeNanos != -1) { - endTimeMillis = -1; + endTimeNanos = -1; Sample next; while ((next = queue.poll()) != null && result.size() <= count) result.add(next); @@ -67,7 +69,7 @@ public synchronized List> finishSampling(int count) @Override protected synchronized void insert(T item, long value) { - if (value > 0 && clock.currentTimeMillis() <= endTimeMillis + if (value > 0 && clock.now() <= endTimeNanos && (queue.isEmpty() || queue.size() < capacity || queue.peekLast().count < value)) queue.add(new Sample(item, value, 0)); } diff --git a/src/java/org/apache/cassandra/metrics/MessagingMetrics.java b/src/java/org/apache/cassandra/metrics/MessagingMetrics.java index 2f096f6c371c..0ea2e10ccb8d 100644 --- a/src/java/org/apache/cassandra/metrics/MessagingMetrics.java +++ b/src/java/org/apache/cassandra/metrics/MessagingMetrics.java @@ -17,59 +17,215 @@ */ package org.apache.cassandra.metrics; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.Map; import 
java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; -import org.apache.cassandra.config.DatabaseDescriptor; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.DatabaseDescriptor; + import com.codahale.metrics.Timer; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.net.InboundMessageHandlers; +import org.apache.cassandra.net.LatencyConsumer; +import org.apache.cassandra.utils.StatusLogger; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics; /** * Metrics for messages */ -public class MessagingMetrics +public class MessagingMetrics implements InboundMessageHandlers.GlobalMetricCallbacks { - private static Logger logger = LoggerFactory.getLogger(MessagingMetrics.class); private static final MetricNameFactory factory = new DefaultNameFactory("Messaging"); - public final Timer crossNodeLatency; - public final ConcurrentHashMap dcLatency; - public final ConcurrentHashMap queueWaitLatency; + private static final Logger logger = LoggerFactory.getLogger(MessagingMetrics.class); + private static final int LOG_DROPPED_INTERVAL_IN_MS = 5000; + + public static class DCLatencyRecorder implements LatencyConsumer + { + public final Timer dcLatency; + public final Timer allLatency; + + DCLatencyRecorder(Timer dcLatency, Timer allLatency) + { + this.dcLatency = dcLatency; + this.allLatency = allLatency; + } + + public void accept(long timeTaken, TimeUnit units) + { + if (timeTaken > 0) + { + dcLatency.update(timeTaken, units); + allLatency.update(timeTaken, units); + } + } + } + + private static final class 
DroppedForVerb + { + final DroppedMessageMetrics metrics; + final AtomicInteger droppedFromSelf; + final AtomicInteger droppedFromPeer; + + DroppedForVerb(Verb verb) + { + this(new DroppedMessageMetrics(verb)); + } + + DroppedForVerb(DroppedMessageMetrics metrics) + { + this.metrics = metrics; + this.droppedFromSelf = new AtomicInteger(0); + this.droppedFromPeer = new AtomicInteger(0); + } + } + + private final Timer allLatency; + public final ConcurrentHashMap dcLatency; + public final EnumMap internalLatency; + + // total dropped message counts for server lifetime + private final Map droppedMessages = new EnumMap<>(Verb.class); public MessagingMetrics() { - crossNodeLatency = Metrics.timer(factory.createMetricName("CrossNodeLatency")); + allLatency = Metrics.timer(factory.createMetricName("CrossNodeLatency")); dcLatency = new ConcurrentHashMap<>(); - queueWaitLatency = new ConcurrentHashMap<>(); + internalLatency = new EnumMap<>(Verb.class); + for (Verb verb : Verb.VERBS) + internalLatency.put(verb, Metrics.timer(factory.createMetricName(verb + "-WaitLatency"))); + for (Verb verb : Verb.values()) + droppedMessages.put(verb, new DroppedForVerb(verb)); + } + + public DCLatencyRecorder internodeLatencyRecorder(InetAddressAndPort from) + { + String dcName = DatabaseDescriptor.getEndpointSnitch().getDatacenter(from); + DCLatencyRecorder dcUpdater = dcLatency.get(dcName); + if (dcUpdater == null) + dcUpdater = dcLatency.computeIfAbsent(dcName, k -> new DCLatencyRecorder(Metrics.timer(factory.createMetricName(dcName + "-Latency")), allLatency)); + return dcUpdater; + } + + public void recordInternalLatency(Verb verb, long timeTaken, TimeUnit units) + { + if (timeTaken > 0) + internalLatency.get(verb).update(timeTaken, units); + } + + public void recordSelfDroppedMessage(Verb verb) + { + recordDroppedMessage(droppedMessages.get(verb), false); } - public void addTimeTaken(InetAddressAndPort from, long timeTaken) + public void recordSelfDroppedMessage(Verb verb, long 
timeElapsed, TimeUnit timeUnit) { - String dc = DatabaseDescriptor.getEndpointSnitch().getDatacenter(from); - Timer timer = dcLatency.get(dc); - if (timer == null) + recordDroppedMessage(verb, timeElapsed, timeUnit, false); + } + + public void recordInternodeDroppedMessage(Verb verb, long timeElapsed, TimeUnit timeUnit) + { + recordDroppedMessage(verb, timeElapsed, timeUnit, true); + } + + public void recordDroppedMessage(Message message, long timeElapsed, TimeUnit timeUnit) + { + recordDroppedMessage(message.verb(), timeElapsed, timeUnit, message.isCrossNode()); + } + + public void recordDroppedMessage(Verb verb, long timeElapsed, TimeUnit timeUnit, boolean isCrossNode) + { + recordDroppedMessage(droppedMessages.get(verb), timeElapsed, timeUnit, isCrossNode); + } + + private static void recordDroppedMessage(DroppedForVerb droppedMessages, long timeTaken, TimeUnit units, boolean isCrossNode) + { + if (isCrossNode) + droppedMessages.metrics.crossNodeDroppedLatency.update(timeTaken, units); + else + droppedMessages.metrics.internalDroppedLatency.update(timeTaken, units); + recordDroppedMessage(droppedMessages, isCrossNode); + } + + private static void recordDroppedMessage(DroppedForVerb droppedMessages, boolean isCrossNode) + { + droppedMessages.metrics.dropped.mark(); + if (isCrossNode) + droppedMessages.droppedFromPeer.incrementAndGet(); + else + droppedMessages.droppedFromSelf.incrementAndGet(); + } + + public void scheduleLogging() + { + ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay(this::logDroppedMessages, + LOG_DROPPED_INTERVAL_IN_MS, + LOG_DROPPED_INTERVAL_IN_MS, + MILLISECONDS); + } + + public Map getDroppedMessages() + { + Map map = new HashMap<>(droppedMessages.size()); + for (Map.Entry entry : droppedMessages.entrySet()) + map.put(entry.getKey().toString(), (int) entry.getValue().metrics.dropped.getCount()); + return map; + } + + private void logDroppedMessages() + { + if (resetAndConsumeDroppedErrors(logger::info) > 0) + StatusLogger.log(); + 
} + + @VisibleForTesting + public int resetAndConsumeDroppedErrors(Consumer messageConsumer) + { + int count = 0; + for (Map.Entry entry : droppedMessages.entrySet()) { - timer = dcLatency.computeIfAbsent(dc, k -> Metrics.timer(factory.createMetricName(dc + "-Latency"))); + Verb verb = entry.getKey(); + DroppedForVerb droppedForVerb = entry.getValue(); + + int droppedInternal = droppedForVerb.droppedFromSelf.getAndSet(0); + int droppedCrossNode = droppedForVerb.droppedFromPeer.getAndSet(0); + if (droppedInternal > 0 || droppedCrossNode > 0) + { + messageConsumer.accept(String.format("%s messages were dropped in last %d ms: %d internal and %d cross node." + + " Mean internal dropped latency: %d ms and Mean cross-node dropped latency: %d ms", + verb, + LOG_DROPPED_INTERVAL_IN_MS, + droppedInternal, + droppedCrossNode, + TimeUnit.NANOSECONDS.toMillis((long) droppedForVerb.metrics.internalDroppedLatency.getSnapshot().getMean()), + TimeUnit.NANOSECONDS.toMillis((long) droppedForVerb.metrics.crossNodeDroppedLatency.getSnapshot().getMean()))); + ++count; + } } - timer.update(timeTaken, TimeUnit.MILLISECONDS); - crossNodeLatency.update(timeTaken, TimeUnit.MILLISECONDS); + return count; } - public void addQueueWaitTime(String verb, long timeTaken) + @VisibleForTesting + public void resetDroppedMessages(String scope) { - if (timeTaken < 0) - // the measurement is not accurate, ignore the negative timeTaken - return; - - Timer timer = queueWaitLatency.get(verb); - if (timer == null) + for (Verb verb : droppedMessages.keySet()) { - timer = queueWaitLatency.computeIfAbsent(verb, k -> Metrics.timer(factory.createMetricName(verb + "-WaitLatency"))); + droppedMessages.put(verb, new DroppedForVerb(new DroppedMessageMetrics(metricName -> + new CassandraMetricsRegistry.MetricName("DroppedMessages", metricName, scope) + ))); } - timer.update(timeTaken, TimeUnit.MILLISECONDS); } + } diff --git a/src/java/org/apache/cassandra/metrics/Sampler.java 
b/src/java/org/apache/cassandra/metrics/Sampler.java index 4bff332eeaab..cfe3f3b24943 100644 --- a/src/java/org/apache/cassandra/metrics/Sampler.java +++ b/src/java/org/apache/cassandra/metrics/Sampler.java @@ -26,8 +26,8 @@ import org.apache.cassandra.concurrent.JMXEnabledThreadPoolExecutor; import org.apache.cassandra.concurrent.NamedThreadFactory; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.MessagingService.Verb; -import org.apache.cassandra.utils.Clock; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.utils.MonotonicClock; import com.google.common.annotations.VisibleForTesting; @@ -39,7 +39,7 @@ public enum SamplerType } @VisibleForTesting - Clock clock = Clock.instance; + MonotonicClock clock = MonotonicClock.approxTime; @VisibleForTesting static final ThreadPoolExecutor samplerExecutor = new JMXEnabledThreadPoolExecutor(1, 1, @@ -52,7 +52,7 @@ public enum SamplerType { samplerExecutor.setRejectedExecutionHandler((runnable, executor) -> { - MessagingService.instance().incrementDroppedMessages(Verb._SAMPLE); + MessagingService.instance().metrics.recordSelfDroppedMessage(Verb._SAMPLE); }); } diff --git a/src/java/org/apache/cassandra/metrics/TableMetrics.java b/src/java/org/apache/cassandra/metrics/TableMetrics.java index c854c43c7cfc..d8330cc2f03c 100644 --- a/src/java/org/apache/cassandra/metrics/TableMetrics.java +++ b/src/java/org/apache/cassandra/metrics/TableMetrics.java @@ -202,7 +202,7 @@ public class TableMetrics /** Time spent waiting for free memtable space, either on- or off-heap */ public final Histogram waitingOnFreeMemtableSpace; - /** Dropped Mutations Count */ + @Deprecated public final Counter droppedMutations; private final MetricNameFactory factory; @@ -523,7 +523,7 @@ public long[] getValue() { public EstimatedHistogram getHistogram(SSTableReader reader) { - return reader.getEstimatedColumnCount(); + return reader.getEstimatedCellPerPartitionCount(); } }); } diff --git 
a/src/java/org/apache/cassandra/net/IMessageSink.java b/src/java/org/apache/cassandra/net/AcceptVersions.java similarity index 62% rename from src/java/org/apache/cassandra/net/IMessageSink.java rename to src/java/org/apache/cassandra/net/AcceptVersions.java index 090d2c21cd91..61ae0491a18c 100644 --- a/src/java/org/apache/cassandra/net/IMessageSink.java +++ b/src/java/org/apache/cassandra/net/AcceptVersions.java @@ -17,21 +17,26 @@ */ package org.apache.cassandra.net; -import org.apache.cassandra.locator.InetAddressAndPort; - -public interface IMessageSink +/** + * Encapsulates minimum and maximum messaging versions that a node accepts. + */ +class AcceptVersions { - /** - * Allow or drop an outgoing message - * - * @return true if the message is allowed, false if it should be dropped - */ - boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to); + final int min, max; + + AcceptVersions(int min, int max) + { + this.min = min; + this.max = max; + } + + @Override + public boolean equals(Object that) + { + if (!(that instanceof AcceptVersions)) + return false; - /** - * Allow or drop an incoming message - * - * @return true if the message is allowed, false if it should be dropped - */ - boolean allowIncomingMessage(MessageIn message, int id); + return min == ((AcceptVersions) that).min + && max == ((AcceptVersions) that).max; + } } diff --git a/src/java/org/apache/cassandra/net/AsyncChannelOutputPlus.java b/src/java/org/apache/cassandra/net/AsyncChannelOutputPlus.java new file mode 100644 index 000000000000..163981c901d3 --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncChannelOutputPlus.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.locks.LockSupport; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import org.apache.cassandra.io.util.BufferedDataOutputStreamPlus; +import org.apache.cassandra.io.util.DataOutputStreamPlus; + +import static java.lang.Math.max; + +/** + * A {@link DataOutputStreamPlus} that writes ASYNCHRONOUSLY to a Netty Channel. + * + * The close() and flush() methods synchronously wait for pending writes, and will propagate any exceptions + * encountered in writing them to the wire. + * + * The correctness of this class depends on the ChannelPromise we create against a Channel always being completed, + * which appears to be a guarantee provided by Netty so long as the event loop is running. + * + * There are two logical threads accessing the state in this class: the eventLoop of the channel, and the writer + * (the writer thread may change, so long as only one utilises the class at any time). + * Each thread has exclusive write access to certain state in the class, with the other thread only viewing the state, + * simplifying concurrency considerations. 
+ */ +public abstract class AsyncChannelOutputPlus extends BufferedDataOutputStreamPlus +{ + public static class FlushException extends IOException + { + public FlushException(String message) + { + super(message); + } + + public FlushException(String message, Throwable cause) + { + super(message, cause); + } + } + + final Channel channel; + + /** the number of bytes we have begun flushing; updated only by writer */ + private volatile long flushing; + /** the number of bytes we have finished flushing, successfully or otherwise; updated only by eventLoop */ + private volatile long flushed; + /** the number of bytes we have finished flushing to the network; updated only by eventLoop */ + private long flushedToNetwork; + /** any error that has been thrown during a flush; updated only by eventLoop */ + private volatile Throwable flushFailed; + + /** + * state for pausing until flushing has caught up - store the number of bytes we need to be flushed before + * we should be signalled, and store ourselves in {@link #waiting}; once the flushing thread exceeds this many + * total bytes flushed, any Thread stored in waiting will be signalled. + * + * This works exactly like using a WaitQueue, except that we only need to manage a single waiting thread. + */ + private volatile long signalWhenFlushed; // updated only by writer + private volatile Thread waiting; // updated only by writer + + public AsyncChannelOutputPlus(Channel channel) + { + super(null, null); + this.channel = channel; + } + + /** + * Create a ChannelPromise for a flush of the given size. + *

+ * This method will not return until the write is permitted by the provided watermarks and in flight bytes, + * and on its completion will mark the requested bytes flushed. + *

+ * If this method returns normally, the ChannelPromise MUST be writtenAndFlushed, or else completed exceptionally. + */ + protected ChannelPromise beginFlush(int byteCount, int lowWaterMark, int highWaterMark) throws IOException + { + waitForSpace(byteCount, lowWaterMark, highWaterMark); + + return AsyncChannelPromise.withListener(channel, future -> { + if (future.isSuccess() && null == flushFailed) + { + flushedToNetwork += byteCount; + releaseSpace(byteCount); + } + else if (null == flushFailed) + { + Throwable cause = future.cause(); + if (cause == null) + { + cause = new FlushException("Flush failed for unknown reason"); + cause.fillInStackTrace(); + } + flushFailed = cause; + releaseSpace(flushing - flushed); + } + else + { + assert flushing == flushed; + } + }); + } + + /** + * Imposes our lowWaterMark/highWaterMark constraints, and propagates any exceptions thrown by prior flushes. + * + * If we currently have lowWaterMark or fewer bytes flushing, we are good to go. + * If our new write will not take us over our highWaterMark, we are good to go. + * Otherwise we wait until either of these conditions are met. + * + * This may only be invoked by the writer thread, never by the eventLoop. + * + * @throws IOException if a prior asynchronous flush failed + */ + private void waitForSpace(int bytesToWrite, int lowWaterMark, int highWaterMark) throws IOException + { + // decide when we would be willing to carry on writing + // we are always writable if we have lowWaterMark or fewer bytes, no matter how many bytes we are flushing + // our callers should not be supplying more than (highWaterMark - lowWaterMark) bytes, but we must work correctly if they do + int wakeUpWhenFlushing = highWaterMark - bytesToWrite; + waitUntilFlushed(max(lowWaterMark, wakeUpWhenFlushing), lowWaterMark); + flushing += bytesToWrite; + } + + /** + * Implementation of waitForSpace, which calculates what flushed points we need to wait for, + * parks if necessary and propagates flush failures. 
+ * + * This may only be invoked by the writer thread, never by the eventLoop. + */ + void waitUntilFlushed(int wakeUpWhenExcessBytesWritten, int signalWhenExcessBytesWritten) throws IOException + { + // we assume that we are happy to wake up at least as early as we will be signalled; otherwise we will never exit + assert signalWhenExcessBytesWritten <= wakeUpWhenExcessBytesWritten; + // flushing shouldn't change during this method invocation, so our calculations for signal and flushed are consistent + long wakeUpWhenFlushed = flushing - wakeUpWhenExcessBytesWritten; + if (flushed < wakeUpWhenFlushed) + parkUntilFlushed(wakeUpWhenFlushed, flushing - signalWhenExcessBytesWritten); + propagateFailedFlush(); + } + + /** + * Utility method for waitUntilFlushed, which actually parks the current thread until the necessary + * number of bytes have been flushed + * + * This may only be invoked by the writer thread, never by the eventLoop. + */ + protected void parkUntilFlushed(long wakeUpWhenFlushed, long signalWhenFlushed) + { + assert wakeUpWhenFlushed <= signalWhenFlushed; + assert waiting == null; + this.waiting = Thread.currentThread(); + this.signalWhenFlushed = signalWhenFlushed; + + while (flushed < wakeUpWhenFlushed) + LockSupport.park(); + waiting = null; + } + + /** + * Update our flushed count, and signal any waiters. + * + * This may only be invoked by the eventLoop, never by the writer thread. 
+ */ + protected void releaseSpace(long bytesFlushed) + { + long newFlushed = flushed + bytesFlushed; + flushed = newFlushed; + + Thread thread = waiting; + if (thread != null && signalWhenFlushed <= newFlushed) + LockSupport.unpark(thread); + } + + private void propagateFailedFlush() throws IOException + { + Throwable t = flushFailed; + if (t != null) + { + if (SocketFactory.isCausedByConnectionReset(t)) + throw new FlushException("The channel this output stream was writing to has been closed", t); + throw new FlushException("This output stream is in an unsafe state after an asynchronous flush failed", t); + } + } + + @Override + abstract protected void doFlush(int count) throws IOException; + + abstract public long position(); + + public long flushed() + { + // external flushed (that which has had flush() invoked implicitly or otherwise) == internal flushing + return flushing; + } + + public long flushedToNetwork() + { + return flushedToNetwork; + } + + /** + * Perform an asynchronous flush, then waits until all outstanding flushes have completed + * + * @throws IOException if any flush fails + */ + @Override + public void flush() throws IOException + { + doFlush(0); + waitUntilFlushed(0, 0); + } + + /** + * Flush any remaining writes, and release any buffers. + * + * The channel is not closed, as it is assumed to be managed externally. + * + * WARNING: This method requires mutual exclusivity with all other producer methods to run safely. + * It should only be invoked by the owning thread, never the eventLoop; the eventLoop should propagate + * errors to {@link #flushFailed}, which will propagate them to the producer thread no later than its + * final invocation to {@link #close()} or {@link #flush()} (that must not be followed by any further writes). + */ + @Override + public void close() throws IOException + { + try + { + flush(); + } + finally + { + discard(); + } + } + + /** + * Discard any buffered data, and the buffers that contain it. 
+ * May be invoked instead of {@link #close()} if we terminate exceptionally. + */ + public abstract void discard(); + + @Override + protected WritableByteChannel newDefaultChannel() + { + throw new UnsupportedOperationException(); + } + +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/AsyncChannelPromise.java b/src/java/org/apache/cassandra/net/AsyncChannelPromise.java new file mode 100644 index 000000000000..d2c9d0bfb6de --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncChannelPromise.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +/** + * See {@link AsyncPromise} and {@link io.netty.channel.ChannelPromise} + * + * This class is all boiler plate, just ensuring we return ourselves and invoke the correct Promise method. 
+ */ +public class AsyncChannelPromise extends AsyncPromise implements ChannelPromise +{ + private final Channel channel; + + @SuppressWarnings("unused") + public AsyncChannelPromise(Channel channel) + { + super(channel.eventLoop()); + this.channel = channel; + } + + AsyncChannelPromise(Channel channel, GenericFutureListener> listener) + { + super(channel.eventLoop(), listener); + this.channel = channel; + } + + public static AsyncChannelPromise withListener(ChannelHandlerContext context, GenericFutureListener> listener) + { + return withListener(context.channel(), listener); + } + + public static AsyncChannelPromise withListener(Channel channel, GenericFutureListener> listener) + { + return new AsyncChannelPromise(channel, listener); + } + + public static ChannelFuture writeAndFlush(ChannelHandlerContext context, Object message, GenericFutureListener> listener) + { + return context.writeAndFlush(message, withListener(context.channel(), listener)); + } + + public static ChannelFuture writeAndFlush(Channel channel, Object message, GenericFutureListener> listener) + { + return channel.writeAndFlush(message, withListener(channel, listener)); + } + + public static ChannelFuture writeAndFlush(ChannelHandlerContext context, Object message) + { + return context.writeAndFlush(message, new AsyncChannelPromise(context.channel())); + } + + public static ChannelFuture writeAndFlush(Channel channel, Object message) + { + return channel.writeAndFlush(message, new AsyncChannelPromise(channel)); + } + + public Channel channel() + { + return channel; + } + + public boolean isVoid() + { + return false; + } + + public ChannelPromise setSuccess() + { + return setSuccess(null); + } + + public ChannelPromise setSuccess(Void v) + { + super.setSuccess(v); + return this; + } + + public boolean trySuccess() + { + return trySuccess(null); + } + + public ChannelPromise setFailure(Throwable throwable) + { + super.setFailure(throwable); + return this; + } + + public ChannelPromise sync() throws 
InterruptedException + { + super.sync(); + return this; + } + + public ChannelPromise syncUninterruptibly() + { + super.syncUninterruptibly(); + return this; + } + + public ChannelPromise await() throws InterruptedException + { + super.await(); + return this; + } + + public ChannelPromise awaitUninterruptibly() + { + super.awaitUninterruptibly(); + return this; + } + + public ChannelPromise addListener(GenericFutureListener> listener) + { + super.addListener(listener); + return this; + } + + public ChannelPromise addListeners(GenericFutureListener>... listeners) + { + super.addListeners(listeners); + return this; + } + + public ChannelPromise removeListener(GenericFutureListener> listener) + { + super.removeListener(listener); + return this; + } + + public ChannelPromise removeListeners(GenericFutureListener>... listeners) + { + super.removeListeners(listeners); + return this; + } + + public ChannelPromise unvoid() + { + return this; + } +} diff --git a/src/java/org/apache/cassandra/net/AsyncMessageOutputPlus.java b/src/java/org/apache/cassandra/net/AsyncMessageOutputPlus.java new file mode 100644 index 000000000000..8ef0a8f0e631 --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncMessageOutputPlus.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.WriteBufferWaterMark; +import org.apache.cassandra.io.util.DataOutputStreamPlus; + +/** + * A {@link DataOutputStreamPlus} that writes ASYNCHRONOUSLY to a Netty Channel. + * + * Intended as single use, to write one (large) message. + * + * The close() and flush() methods synchronously wait for pending writes, and will propagate any exceptions + * encountered in writing them to the wire. + * + * The correctness of this class depends on the ChannelPromise we create against a Channel always being completed, + * which appears to be a guarantee provided by Netty so long as the event loop is running. + */ +public class AsyncMessageOutputPlus extends AsyncChannelOutputPlus +{ + /** + * the maximum {@link #highWaterMark} and minimum {@link #lowWaterMark} number of bytes we have flushing + * during which we should still be writing to the channel. + * + * i.e., if we are at or below the {@link #lowWaterMark} we should definitely start writing again; + * if we are at or above the {@link #highWaterMark} we should definitely stop writing; + * if we are inbetween, it is OK to either write or not write + * + * note that we consider the bytes we are about to write to our high water mark, but not our low. + * i.e., we will not begin a write that would take us over our high water mark, unless not doing so would + * take us below our low water mark. + * + * This is somewhat arbitrary accounting, and a meaningless distinction for flushes of a consistent size. 
+ */ + @SuppressWarnings("JavaDoc") + private final int highWaterMark; + private final int lowWaterMark; + private final int bufferSize; + private final int messageSize; + private boolean closing; + + private final FrameEncoder.PayloadAllocator payloadAllocator; + private volatile FrameEncoder.Payload payload; + + AsyncMessageOutputPlus(Channel channel, int bufferSize, int messageSize, FrameEncoder.PayloadAllocator payloadAllocator) + { + super(channel); + WriteBufferWaterMark waterMark = channel.config().getWriteBufferWaterMark(); + this.lowWaterMark = waterMark.low(); + this.highWaterMark = waterMark.high(); + this.messageSize = messageSize; + this.bufferSize = Math.min(messageSize, bufferSize); + this.payloadAllocator = payloadAllocator; + allocateBuffer(); + } + + private void allocateBuffer() + { + payload = payloadAllocator.allocate(false, bufferSize); + buffer = payload.buffer; + } + + @Override + protected void doFlush(int count) throws IOException + { + if (!channel.isOpen()) + throw new ClosedChannelException(); + + // flush the current backing write buffer only if there's any pending data + FrameEncoder.Payload flush = payload; + int byteCount = flush.length(); + if (byteCount == 0) + return; + + if (byteCount + flushed() > (closing ? messageSize : messageSize - 1)) + throw new InvalidSerializedSizeException(messageSize, byteCount + flushed()); + + flush.finish(); + ChannelPromise promise = beginFlush(byteCount, lowWaterMark, highWaterMark); + channel.writeAndFlush(flush, promise); + allocateBuffer(); + } + + public void close() throws IOException + { + closing = true; + if (flushed() == 0 && payload != null) + payload.setSelfContained(true); + super.close(); + } + + public long position() + { + return flushed() + payload.length(); + } + + /** + * Discard any buffered data, and the buffers that contain it. + * May be invoked instead of {@link #close()} if we terminate exceptionally. 
+ */ + public void discard() + { + if (payload != null) + { + payload.release(); + payload = null; + buffer = null; + } + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/AsyncOneResponse.java b/src/java/org/apache/cassandra/net/AsyncOneResponse.java index 3fe0a2aebbd6..ba83c84c91cb 100644 --- a/src/java/org/apache/cassandra/net/AsyncOneResponse.java +++ b/src/java/org/apache/cassandra/net/AsyncOneResponse.java @@ -17,54 +17,31 @@ */ package org.apache.cassandra.net; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.AbstractFuture; + +import io.netty.util.concurrent.ImmediateEventExecutor; /** * A callback specialized for returning a value from a single target; that is, this is for messages * that we only send to one recipient. */ -public class AsyncOneResponse extends AbstractFuture implements IAsyncCallback +public class AsyncOneResponse extends AsyncPromise implements RequestCallback { - private final long start = System.nanoTime(); - - public void response(MessageIn response) - { - set(response.payload); - } - - public boolean isLatencyForSnitch() + public AsyncOneResponse() { - return false; + super(ImmediateEventExecutor.INSTANCE); } - @Override - public T get(long timeout, TimeUnit unit) throws TimeoutException + public void onResponse(Message response) { - long adjustedTimeout = unit.toNanos(timeout) - (System.nanoTime() - start); - if (adjustedTimeout <= 0) - { - throw new TimeoutException("Operation timed out."); - } - try - { - return super.get(adjustedTimeout, TimeUnit.NANOSECONDS); - } - catch (InterruptedException | ExecutionException e) - { - throw new AssertionError(e); - } + setSuccess(response.payload); } @VisibleForTesting public static AsyncOneResponse immediate(T value) { AsyncOneResponse response = new AsyncOneResponse<>(); - 
response.set(value); + response.setSuccess(value); return response; } } diff --git a/src/java/org/apache/cassandra/net/AsyncPromise.java b/src/java/org/apache/cassandra/net/AsyncPromise.java new file mode 100644 index 000000000000..36bc304a6b08 --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncPromise.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.ThrowableUtil; +import org.apache.cassandra.utils.concurrent.WaitQueue; + +import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.*; + +/** + * Netty's DefaultPromise uses a mutex to coordinate notifiers AND waiters between the eventLoop and the other threads. + * Since we register cross-thread listeners, this has the potential to block internode messaging for an unknown + * number of threads for an unknown period of time, if we are unlucky with the scheduler (which will certainly + * happen, just with some unknown but low periodicity) + * + * At the same time, we manage some other efficiencies: + * - We save some space when registering listeners, especially if there is only one listener, as we perform no + * extra allocations in this case. + * - We permit efficient initial state declaration, avoiding unnecessary CAS or lock acquisitions when mutating + * a Promise we are ourselves constructing (and can easily add more; only those we use have been added) + * + * We can also make some guarantees about our behaviour here, although we primarily mirror Netty. + * Specifically, we can guarantee that notifiers are always invoked in the order they are added (which may be true + * for netty, but was unclear and is not declared). 
This is useful for ensuring the correctness of some of our + * behaviours in OutboundConnection without having to jump through extra hoops. + * + * The implementation loosely follows that of Netty's DefaultPromise, with some slight changes; notably that we have + * no synchronisation on our listeners, instead using a CoW list that is cleared each time we notify listeners. + * + * We handle special values slightly differently. We do not use a special value for null, instead using + * a special value to indicate the result has not been set yet. This means that once isSuccess() holds, + * the result must be a correctly typed object (modulo generics pitfalls). + * All special values are also instances of FailureHolder, which simplifies a number of the logical conditions. + * + * @param + */ +public class AsyncPromise implements Promise +{ + private static final Logger logger = LoggerFactory.getLogger(AsyncPromise.class); + + private final EventExecutor executor; + private volatile Object result; + private volatile GenericFutureListener> listeners; + private volatile WaitQueue waiting; + private static final AtomicReferenceFieldUpdater resultUpdater = newUpdater(AsyncPromise.class, Object.class, "result"); + private static final AtomicReferenceFieldUpdater listenersUpdater = newUpdater(AsyncPromise.class, GenericFutureListener.class, "listeners"); + private static final AtomicReferenceFieldUpdater waitingUpdater = newUpdater(AsyncPromise.class, WaitQueue.class, "waiting"); + + private static final FailureHolder UNSET = new FailureHolder(null); + private static final FailureHolder UNCANCELLABLE = new FailureHolder(null); + private static final FailureHolder CANCELLED = new FailureHolder(ThrowableUtil.unknownStackTrace(new CancellationException(), AsyncPromise.class, "cancel(...)")); + + private static final DeferredGenericFutureListener NOTIFYING = future -> {}; + private static interface DeferredGenericFutureListener> extends GenericFutureListener {} + + private static 
final class FailureHolder + { + final Throwable cause; + private FailureHolder(Throwable cause) + { + this.cause = cause; + } + } + + public AsyncPromise(EventExecutor executor) + { + this(executor, UNSET); + } + + private AsyncPromise(EventExecutor executor, FailureHolder initialState) + { + this.executor = executor; + this.result = initialState; + } + + public AsyncPromise(EventExecutor executor, GenericFutureListener> listener) + { + this(executor); + this.listeners = listener; + } + + AsyncPromise(EventExecutor executor, FailureHolder initialState, GenericFutureListener> listener) + { + this(executor, initialState); + this.listeners = listener; + } + + public static AsyncPromise uncancellable(EventExecutor executor) + { + return new AsyncPromise<>(executor, UNCANCELLABLE); + } + + public static AsyncPromise uncancellable(EventExecutor executor, GenericFutureListener> listener) + { + return new AsyncPromise<>(executor, UNCANCELLABLE); + } + + public Promise setSuccess(V v) + { + if (!trySuccess(v)) + throw new IllegalStateException("complete already: " + this); + return this; + } + + public Promise setFailure(Throwable throwable) + { + if (!tryFailure(throwable)) + throw new IllegalStateException("complete already: " + this); + return this; + } + + public boolean trySuccess(V v) + { + return trySet(v); + } + + public boolean tryFailure(Throwable throwable) + { + return trySet(new FailureHolder(throwable)); + } + + public boolean setUncancellable() + { + if (trySet(UNCANCELLABLE)) + return true; + return result == UNCANCELLABLE; + } + + public boolean cancel(boolean b) + { + return trySet(CANCELLED); + } + + /** + * Shared implementation of various promise completion methods. + * Updates the result if it is possible to do so, returning success/failure. 
+ * + * If the promise is UNSET the new value will succeed; + * if it is UNCANCELLABLE it will succeed only if the new value is not CANCELLED + * otherwise it will fail, as isDone() is implied + * + * If the update succeeds, and the new state implies isDone(), any listeners and waiters will be notified + */ + private boolean trySet(Object v) + { + while (true) + { + Object current = result; + if (isDone(current) || (current == UNCANCELLABLE && v == CANCELLED)) + return false; + if (resultUpdater.compareAndSet(this, current, v)) + { + if (v != UNCANCELLABLE) + { + notifyListeners(); + notifyWaiters(); + } + return true; + } + } + } + + public boolean isSuccess() + { + return isSuccess(result); + } + + private static boolean isSuccess(Object result) + { + return !(result instanceof FailureHolder); + } + + public boolean isCancelled() + { + return isCancelled(result); + } + + private static boolean isCancelled(Object result) + { + return result == CANCELLED; + } + + public boolean isDone() + { + return isDone(result); + } + + private static boolean isDone(Object result) + { + return result != UNSET && result != UNCANCELLABLE; + } + + public boolean isCancellable() + { + Object result = this.result; + return result == UNSET; + } + + public Throwable cause() + { + Object result = this.result; + if (result instanceof FailureHolder) + return ((FailureHolder) result).cause; + return null; + } + + /** + * if isSuccess(), returns the value, otherwise returns null + */ + @SuppressWarnings("unchecked") + public V getNow() + { + Object result = this.result; + if (isSuccess(result)) + return (V) result; + return null; + } + + public V get() throws InterruptedException, ExecutionException + { + await(); + return getWhenDone(); + } + + public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException + { + if (!await(timeout, unit)) + throw new TimeoutException(); + return getWhenDone(); + } + + /** + * Shared implementation of get() after 
suitable await(); assumes isDone(), and returns + * either the success result or throws the suitable exception under failure + */ + @SuppressWarnings("unchecked") + private V getWhenDone() throws ExecutionException + { + Object result = this.result; + if (isSuccess(result)) + return (V) result; + if (result == CANCELLED) + throw new CancellationException(); + throw new ExecutionException(((FailureHolder) result).cause); + } + + /** + * waits for completion; in case of failure rethrows the original exception without a new wrapping exception + * so may cause problems for reporting stack traces + */ + public Promise sync() throws InterruptedException + { + await(); + rethrowIfFailed(); + return this; + } + + /** + * waits for completion; in case of failure rethrows the original exception without a new wrapping exception + * so may cause problems for reporting stack traces + */ + public Promise syncUninterruptibly() + { + awaitUninterruptibly(); + rethrowIfFailed(); + return this; + } + + private void rethrowIfFailed() + { + Throwable cause = this.cause(); + if (cause != null) + { + PlatformDependent.throwException(cause); + } + } + + public Promise addListener(GenericFutureListener> listener) + { + listenersUpdater.accumulateAndGet(this, listener, AsyncPromise::appendListener); + if (isDone()) + notifyListeners(); + return this; + } + + public Promise addListeners(GenericFutureListener> ... listeners) + { + // this could be more efficient if we cared, but we do not + return addListener(future -> { + for (GenericFutureListener> listener : listeners) + AsyncPromise.invokeListener((GenericFutureListener>)listener, future); + }); + } + + public Promise removeListener(GenericFutureListener> listener) + { + throw new UnsupportedOperationException(); + } + + public Promise removeListeners(GenericFutureListener> ... 
listeners) + { + throw new UnsupportedOperationException(); + } + + @SuppressWarnings("unchecked") + private void notifyListeners() + { + if (!executor.inEventLoop()) + { + // submit this method, to guarantee we invoke in the submitted order + executor.execute(this::notifyListeners); + return; + } + + if (listeners == null || listeners instanceof DeferredGenericFutureListener) + return; // either no listeners, or we are already notifying listeners, so we'll get to the new one when ready + + // first run our notifiers + while (true) + { + GenericFutureListener listeners = listenersUpdater.getAndSet(this, NOTIFYING); + if (listeners != null) + invokeListener(listeners, this); + + if (listenersUpdater.compareAndSet(this, NOTIFYING, null)) + return; + } + } + + private static > void invokeListener(GenericFutureListener listener, F future) + { + try + { + listener.operationComplete(future); + } + catch (Throwable t) + { + logger.error("Failed to invoke listener {} to {}", listener, future, t); + } + } + + private static > GenericFutureListener appendListener(GenericFutureListener prevListener, GenericFutureListener newListener) + { + GenericFutureListener result = newListener; + + if (prevListener != null && prevListener != NOTIFYING) + { + result = future -> { + invokeListener(prevListener, future); + // we will wrap the outer invocation with invokeListener, so no need to do it here too + newListener.operationComplete(future); + }; + } + + if (prevListener instanceof DeferredGenericFutureListener) + { + GenericFutureListener wrap = result; + result = (DeferredGenericFutureListener) wrap::operationComplete; + } + + return result; + } + + public Promise await() throws InterruptedException + { + await(0L, (signal, nanos) -> { signal.await(); return true; } ); + return this; + } + + public Promise awaitUninterruptibly() + { + await(0L, (signal, nanos) -> { signal.awaitUninterruptibly(); return true; } ); + return this; + } + + public boolean await(long timeout, TimeUnit 
unit) throws InterruptedException + { + return await(unit.toNanos(timeout), + (signal, nanos) -> signal.awaitUntil(nanos + System.nanoTime())); + } + + public boolean await(long timeoutMillis) throws InterruptedException + { + return await(timeoutMillis, TimeUnit.MILLISECONDS); + } + + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) + { + return await(unit.toNanos(timeout), + (signal, nanos) -> signal.awaitUntilUninterruptibly(nanos + System.nanoTime())); + } + + public boolean awaitUninterruptibly(long timeoutMillis) + { + return awaitUninterruptibly(timeoutMillis, TimeUnit.MILLISECONDS); + } + + interface Awaiter + { + boolean await(WaitQueue.Signal value, long nanos) throws T; + } + + /** + * A clean way to implement each variant of await using lambdas; we permit a nanos parameter + * so that we can implement this without any unnecessary lambda allocations, although not + * all implementations need the nanos parameter (i.e. those that wait indefinitely) + */ + private boolean await(long nanos, Awaiter awaiter) throws T + { + if (isDone()) + return true; + + WaitQueue.Signal await = registerToWait(); + if (null != await) + return awaiter.await(await, nanos); + + return true; + } + + /** + * Register a signal that will be notified when the promise is completed; + * if the promise becomes completed before this signal is registered, null is returned + */ + private WaitQueue.Signal registerToWait() + { + WaitQueue waiting = this.waiting; + if (waiting == null && !waitingUpdater.compareAndSet(this, null, waiting = new WaitQueue())) + waiting = this.waiting; + assert waiting != null; + + WaitQueue.Signal signal = waiting.register(); + if (!isDone()) + return signal; + signal.cancel(); + return null; + } + + private void notifyWaiters() + { + WaitQueue waiting = this.waiting; + if (waiting != null) + waiting.signalAll(); + } + + public String toString() + { + Object result = this.result; + if (isSuccess(result)) + return "(success: " + result + ')'; + 
if (result == UNCANCELLABLE) + return "(uncancellable)"; + if (result == CANCELLED) + return "(cancelled)"; + if (isDone(result)) + return "(failure: " + ((FailureHolder) result).cause + ')'; + return "(incomplete)"; + } +} diff --git a/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java b/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java new file mode 100644 index 000000000000..84fb8ac167e2 --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.Ints; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.apache.cassandra.io.util.RebufferingInputStream; + +// TODO: rewrite +public class AsyncStreamingInputPlus extends RebufferingInputStream +{ + public static class InputTimeoutException extends IOException + { + } + + private static final long DEFAULT_REBUFFER_BLOCK_IN_MILLIS = TimeUnit.MINUTES.toMillis(3); + + private final Channel channel; + + /** + * The parent, or owning, buffer of the current buffer being read from ({@link super#buffer}). + */ + private ByteBuf currentBuf; + + private final BlockingQueue queue; + + private final long rebufferTimeoutNanos; + + private volatile boolean isClosed; + + public AsyncStreamingInputPlus(Channel channel) + { + this(channel, DEFAULT_REBUFFER_BLOCK_IN_MILLIS, TimeUnit.MILLISECONDS); + } + + AsyncStreamingInputPlus(Channel channel, long rebufferTimeout, TimeUnit rebufferTimeoutUnit) + { + super(Unpooled.EMPTY_BUFFER.nioBuffer()); + currentBuf = Unpooled.EMPTY_BUFFER; + + queue = new LinkedBlockingQueue<>(); + rebufferTimeoutNanos = rebufferTimeoutUnit.toNanos(rebufferTimeout); + + this.channel = channel; + channel.config().setAutoRead(false); + } + + /** + * Append a {@link ByteBuf} to the end of the internal queue. + * + * Note: it's expected this method is invoked on the netty event loop.
+ */ + public boolean append(ByteBuf buf) throws IllegalStateException + { + if (isClosed) return false; + + queue.add(buf); + + /* + * it's possible for append() to race with close(), so we need to ensure + * that the bytebuf gets released in that scenario + */ + if (isClosed) + while ((buf = queue.poll()) != null) + buf.release(); + + return true; + } + + /** + * {@inheritDoc} + * + * Release open buffers and poll the {@link #queue} for more data. + *

+ This is best, and more or less expected, to be invoked on a consuming thread (not the event loop) + because if we block on the queue we can't fill it on the event loop (as that's where the buffers are coming from). + * + * @throws EOFException when no further reading from this instance should occur. Implies this instance is closed. + * @throws InputTimeoutException when no new buffers arrive for reading before + * the {@link #rebufferTimeoutNanos} elapses while blocking. It's then not safe to reuse this instance again. + */ + @Override + protected void reBuffer() throws EOFException, InputTimeoutException + { + if (queue.isEmpty()) + channel.read(); + + currentBuf.release(); + currentBuf = null; + buffer = null; + + ByteBuf next = null; + try + { + next = queue.poll(rebufferTimeoutNanos, TimeUnit.NANOSECONDS); + } + catch (InterruptedException ie) + { + // nop + } + + if (null == next) + throw new InputTimeoutException(); + + if (next == Unpooled.EMPTY_BUFFER) // Unpooled.EMPTY_BUFFER is the indicator that the input is closed + throw new EOFException(); + + currentBuf = next; + buffer = next.nioBuffer(); + } + + public interface Consumer + { + int accept(ByteBuffer buffer) throws IOException; + } + + /** + * Consumes bytes in the stream until the given length + */ + public void consume(Consumer consumer, long length) throws IOException + { + while (length > 0) + { + if (!buffer.hasRemaining()) + reBuffer(); + + final int position = buffer.position(); + final int limit = buffer.limit(); + + buffer.limit(position + (int) Math.min(length, limit - position)); + try + { + int copied = consumer.accept(buffer); + buffer.position(position + copied); + length -= copied; + } + finally + { + buffer.limit(limit); + } + } + } + + /** + * {@inheritDoc} + * + * As long as this method is invoked on the consuming thread the returned value will be accurate. + */ + @VisibleForTesting + public int unsafeAvailable() + { + long count = buffer != null ?
buffer.remaining() : 0; + for (ByteBuf buf : queue) + count += buf.readableBytes(); + + return Ints.checkedCast(count); + } + + // TODO:JEB add docs + // TL;DR if there's no Bufs open anywhere here, issue a channel read to try and grab data. + public void maybeIssueRead() + { + if (isEmpty()) + channel.read(); + } + + public boolean isEmpty() + { + return queue.isEmpty() && (buffer == null || !buffer.hasRemaining()); + } + + /** + * {@inheritDoc} + * + * Note: This should be invoked on the consuming thread. + */ + @Override + public void close() + { + if (isClosed) + return; + + if (currentBuf != null) + { + currentBuf.release(); + currentBuf = null; + buffer = null; + } + + while (true) + { + try + { + ByteBuf buf = queue.poll(Long.MAX_VALUE, TimeUnit.NANOSECONDS); + if (buf == Unpooled.EMPTY_BUFFER) + break; + else + buf.release(); + } + catch (InterruptedException e) + { + // + } + } + + isClosed = true; + } + + /** + * Mark this stream as closed, but do not release any of the resources. + * + * Note: this is best to be called from the producer thread. + */ + public void requestClosure() + { + queue.add(Unpooled.EMPTY_BUFFER); + } + + // TODO: let's remove this like we did for AsyncChannelOutputPlus + public ByteBufAllocator getAllocator() + { + return channel.alloc(); + } +} diff --git a/src/java/org/apache/cassandra/net/AsyncStreamingOutputPlus.java b/src/java/org/apache/cassandra/net/AsyncStreamingOutputPlus.java new file mode 100644 index 000000000000..a52070e365fb --- /dev/null +++ b/src/java/org/apache/cassandra/net/AsyncStreamingOutputPlus.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.WriteBufferWaterMark; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.io.util.DataOutputStreamPlus; +import org.apache.cassandra.net.SharedDefaultFileRegion.SharedFileChannel; +import org.apache.cassandra.streaming.StreamManager.StreamRateLimiter; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.lang.Math.min; + +/** + * A {@link DataOutputStreamPlus} that writes ASYNCHRONOUSLY to a Netty Channel. + * + * The close() and flush() methods synchronously wait for pending writes, and will propagate any exceptions + * encountered in writing them to the wire. + * + * The correctness of this class depends on the ChannelPromise we create against a Channel always being completed, + * which appears to be a guarantee provided by Netty so long as the event loop is running. 
+ */ +public class AsyncStreamingOutputPlus extends AsyncChannelOutputPlus +{ + private static final Logger logger = LoggerFactory.getLogger(AsyncStreamingOutputPlus.class); + + final int defaultLowWaterMark; + final int defaultHighWaterMark; + + public AsyncStreamingOutputPlus(Channel channel) + { + super(channel); + WriteBufferWaterMark waterMark = channel.config().getWriteBufferWaterMark(); + this.defaultLowWaterMark = waterMark.low(); + this.defaultHighWaterMark = waterMark.high(); + allocateBuffer(); + } + + private void allocateBuffer() + { + // this buffer is only used for small quantities of data + buffer = BufferPool.getAtLeast(8 << 10, BufferType.OFF_HEAP); + } + + @Override + protected void doFlush(int count) throws IOException + { + if (!channel.isOpen()) + throw new ClosedChannelException(); + + // flush the current backing write buffer only if there's any pending data + ByteBuffer flush = buffer; + if (flush.position() == 0) + return; + + flush.flip(); + int byteCount = flush.limit(); + ChannelPromise promise = beginFlush(byteCount, 0, Integer.MAX_VALUE); + channel.writeAndFlush(GlobalBufferPoolAllocator.wrap(flush), promise); + allocateBuffer(); + } + + public long position() + { + return flushed() + buffer.position(); + } + + public interface BufferSupplier + { + /** + * Request a buffer with at least the given capacity. + * This method may only be invoked once, and the lifetime of buffer it returns will be managed + * by the AsyncChannelOutputPlus it was created for. + */ + ByteBuffer get(int capacity) throws IOException; + } + + public interface Write + { + /** + * Write to a buffer, and flush its contents to the channel. + *

+ * The lifetime of the buffer will be managed by the AsyncChannelOutputPlus you issue this Write to. + * If the method exits successfully, the contents of the buffer will be written to the channel, otherwise + * the buffer will be cleaned and the exception propagated to the caller. + */ + void write(BufferSupplier supplier) throws IOException; + } + + /** + * Provide a lambda that can request a buffer of suitable size, then fill the buffer and have + * that buffer written and flushed to the underlying channel, without having to handle buffer + * allocation, lifetime or cleanup, including in case of exceptions. + *

+ * Any exception thrown by the Write will be propagated to the caller, after any buffer is cleaned up. + */ + public int writeToChannel(Write write, StreamRateLimiter limiter) throws IOException + { + doFlush(0); + class Holder + { + ChannelPromise promise; + ByteBuffer buffer; + } + Holder holder = new Holder(); + + try + { + write.write(size -> { + if (holder.buffer != null) + throw new IllegalStateException("Can only allocate one ByteBuffer"); + limiter.acquire(size); + holder.promise = beginFlush(size, defaultLowWaterMark, defaultHighWaterMark); + holder.buffer = BufferPool.get(size); + return holder.buffer; + }); + } + catch (Throwable t) + { + // we don't currently support cancelling the flush, but at this point we are recoverable if we want + if (holder.buffer != null) + BufferPool.put(holder.buffer); + if (holder.promise != null) + holder.promise.tryFailure(t); + throw t; + } + + ByteBuffer buffer = holder.buffer; + BufferPool.putUnusedPortion(buffer); + + int length = buffer.limit(); + channel.writeAndFlush(GlobalBufferPoolAllocator.wrap(buffer), holder.promise); + return length; + } + + /** + *

+ * Writes all data in file channel to stream, 1MiB at a time, with at most 2MiB in flight at once. + * This method takes ownership of the provided {@code FileChannel}. + *

+ * WARNING: this method blocks only for permission to write to the netty channel; it exits before + * the write is flushed to the network. + */ + public long writeFileToChannel(FileChannel file, StreamRateLimiter limiter) throws IOException + { + // write files in 1MiB chunks, since there may be blocking work performed to fetch it from disk, + // the data is never brought in process and is gated by the wire anyway + return writeFileToChannel(file, limiter, 1 << 20, 1 << 20, 2 << 20); + } + + public long writeFileToChannel(FileChannel file, StreamRateLimiter limiter, int batchSize, int lowWaterMark, int highWaterMark) throws IOException + { + final long length = file.size(); + long bytesTransferred = 0; + + final SharedFileChannel sharedFile = SharedDefaultFileRegion.share(file); + try + { + while (bytesTransferred < length) + { + int toWrite = (int) min(batchSize, length - bytesTransferred); + + limiter.acquire(toWrite); + ChannelPromise promise = beginFlush(toWrite, lowWaterMark, highWaterMark); + + SharedDefaultFileRegion fileRegion = new SharedDefaultFileRegion(sharedFile, bytesTransferred, toWrite); + channel.writeAndFlush(fileRegion, promise); + + if (logger.isTraceEnabled()) + logger.trace("Writing {} bytes at position {} of {}", toWrite, bytesTransferred, length); + bytesTransferred += toWrite; + } + + return bytesTransferred; + } + finally + { + sharedFile.release(); + } + } + + /** + * Discard any buffered data, and the buffers that contain it. + * May be invoked instead of {@link #close()} if we terminate exceptionally. 
+ */ + public void discard() + { + if (buffer != null) + { + BufferPool.put(buffer); + buffer = null; + } + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/BackPressureState.java b/src/java/org/apache/cassandra/net/BackPressureState.java index 886c075468b8..de19bf301f57 100644 --- a/src/java/org/apache/cassandra/net/BackPressureState.java +++ b/src/java/org/apache/cassandra/net/BackPressureState.java @@ -27,7 +27,7 @@ public interface BackPressureState /** * Called when a message is sent to a replica. */ - void onMessageSent(MessageOut message); + void onMessageSent(Message message); /** * Called when a response is received from a replica. diff --git a/src/java/org/apache/cassandra/net/BufferPoolAllocator.java b/src/java/org/apache/cassandra/net/BufferPoolAllocator.java new file mode 100644 index 000000000000..8782c030693b --- /dev/null +++ b/src/java/org/apache/cassandra/net/BufferPoolAllocator.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.buffer.AbstractByteBufAllocator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledUnsafeDirectByteBuf; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.utils.memory.BufferPool; + +/** + * A trivial wrapper around BufferPool for integrating with Netty, but retaining ownership of pooling behaviour + * that is integrated into Cassandra's other pooling. + */ +abstract class BufferPoolAllocator extends AbstractByteBufAllocator +{ + BufferPoolAllocator() + { + super(true); + } + + @Override + public boolean isDirectBufferPooled() + { + return true; + } + + /** shouldn't be invoked */ + @Override + protected ByteBuf newHeapBuffer(int minCapacity, int maxCapacity) + { + return Unpooled.buffer(minCapacity, maxCapacity); + } + + @Override + protected ByteBuf newDirectBuffer(int minCapacity, int maxCapacity) + { + ByteBuf result = new Wrapped(this, getAtLeast(minCapacity)); + result.clear(); + return result; + } + + ByteBuffer get(int size) + { + return BufferPool.get(size, BufferType.OFF_HEAP); + } + + ByteBuffer getAtLeast(int size) + { + return BufferPool.getAtLeast(size, BufferType.OFF_HEAP); + } + + void put(ByteBuffer buffer) + { + BufferPool.put(buffer); + } + + void putUnusedPortion(ByteBuffer buffer) + { + BufferPool.putUnusedPortion(buffer); + } + + void release() + { + } + + /** + * A simple extension to UnpooledUnsafeDirectByteBuf that returns buffers to BufferPool on deallocate, + * and permits extracting the buffer from it to take ownership and use directly. 
+ */ + public static class Wrapped extends UnpooledUnsafeDirectByteBuf + { + private ByteBuffer wrapped; + + Wrapped(BufferPoolAllocator allocator, ByteBuffer wrap) + { + super(allocator, wrap, wrap.capacity()); + wrapped = wrap; + } + + @Override + public void deallocate() + { + if (wrapped != null) + BufferPool.put(wrapped); + } + + public ByteBuffer adopt() + { + if (refCnt() > 1) + throw new IllegalStateException(); + ByteBuffer adopt = wrapped; + adopt.position(readerIndex()).limit(writerIndex()); + wrapped = null; + return adopt; + } + } +} diff --git a/src/java/org/apache/cassandra/net/CallbackInfo.java b/src/java/org/apache/cassandra/net/CallbackInfo.java deleted file mode 100644 index f2ed8a10fa97..000000000000 --- a/src/java/org/apache/cassandra/net/CallbackInfo.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.net; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.locator.InetAddressAndPort; - -/** - * Encapsulates the callback information. - * The ability to set the message is useful in cases for when a hint needs - * to be written due to a timeout in the response from a replica. 
- */ -public class CallbackInfo -{ - protected final InetAddressAndPort target; - protected final IAsyncCallback callback; - protected final IVersionedSerializer serializer; - private final boolean failureCallback; - - /** - * Create CallbackInfo without sent message - * - * @param target target to send message - * @param callback - * @param serializer serializer to deserialize response message - * @param failureCallback True when we have a callback to handle failures - */ - public CallbackInfo(InetAddressAndPort target, IAsyncCallback callback, IVersionedSerializer serializer, boolean failureCallback) - { - this.target = target; - this.callback = callback; - this.serializer = serializer; - this.failureCallback = failureCallback; - } - - public boolean shouldHint() - { - return false; - } - - public boolean isFailureCallback() - { - return failureCallback; - } - - public String toString() - { - return "CallbackInfo(" + - "target=" + target + - ", callback=" + callback + - ", serializer=" + serializer + - ", failureCallback=" + failureCallback + - ')'; - } -} diff --git a/src/java/org/apache/cassandra/net/ChunkedInputPlus.java b/src/java/org/apache/cassandra/net/ChunkedInputPlus.java new file mode 100644 index 000000000000..3aad8d96150e --- /dev/null +++ b/src/java/org/apache/cassandra/net/ChunkedInputPlus.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.EOFException; + +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; + +import org.apache.cassandra.io.util.RebufferingInputStream; + +/** + * A specialised {@link org.apache.cassandra.io.util.DataInputPlus} implementation for deserializing large messages + * that are split over multiple {@link FrameDecoder.Frame}s. + * + * Ensures that every underlying {@link ShareableBytes} frame is released, and promptly so, as frames are consumed. + * + * {@link #close()} MUST be invoked in the end. + */ +class ChunkedInputPlus extends RebufferingInputStream +{ + private final PeekingIterator iter; + + private ChunkedInputPlus(PeekingIterator iter) + { + super(iter.peek().get()); + this.iter = iter; + } + + /** + * Creates a {@link ChunkedInputPlus} from the provided {@link ShareableBytes} buffers. + * + * The provided iterable must contain at least one buffer. + */ + static ChunkedInputPlus of(Iterable buffers) + { + PeekingIterator iter = Iterators.peekingIterator(buffers.iterator()); + if (!iter.hasNext()) + throw new IllegalArgumentException(); + return new ChunkedInputPlus(iter); + } + + @Override + protected void reBuffer() throws EOFException + { + buffer = null; + iter.peek().release(); + iter.next(); + + if (!iter.hasNext()) + throw new EOFException(); + + buffer = iter.peek().get(); + } + + @Override + public void close() + { + buffer = null; + iter.forEachRemaining(ShareableBytes::release); + } + + /** + * Returns the number of unconsumed bytes. 
Will release any outstanding buffers and consume the underlying iterator. + * + * Should only be used for sanity checking, once the input is no longer needed, as it will implicitly close the input. + */ + int remainder() + { + buffer = null; + + int bytes = 0; + while (iter.hasNext()) + { + ShareableBytes chunk = iter.peek(); + bytes += chunk.remaining(); + chunk.release(); + iter.next(); + } + return bytes; + } +} diff --git a/src/java/org/apache/cassandra/net/CompactEndpointSerializationHelper.java b/src/java/org/apache/cassandra/net/CompactEndpointSerializationHelper.java deleted file mode 100644 index b58ca472540a..000000000000 --- a/src/java/org/apache/cassandra/net/CompactEndpointSerializationHelper.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.cassandra.net; - -import java.io.*; -import java.net.Inet4Address; -import java.net.Inet6Address; -import java.net.InetAddress; -import java.nio.ByteBuffer; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.streaming.messages.StreamMessage; - -/* - * As of version 4.0 the endpoint description includes a port number as an unsigned short - */ -public class CompactEndpointSerializationHelper implements IVersionedSerializer -{ - public static final IVersionedSerializer instance = new CompactEndpointSerializationHelper(); - - /** - * Streaming uses its own version numbering so we need to ignore it and always use currrent version. - * There is no cross version streaming so it will always use the latest address serialization. 
- **/ - public static final IVersionedSerializer streamingInstance = new IVersionedSerializer() - { - public void serialize(InetAddressAndPort inetAddressAndPort, DataOutputPlus out, int version) throws IOException - { - instance.serialize(inetAddressAndPort, out, MessagingService.current_version); - } - - public InetAddressAndPort deserialize(DataInputPlus in, int version) throws IOException - { - return instance.deserialize(in, MessagingService.current_version); - } - - public long serializedSize(InetAddressAndPort inetAddressAndPort, int version) - { - return instance.serializedSize(inetAddressAndPort, MessagingService.current_version); - } - }; - - private CompactEndpointSerializationHelper() {} - - public void serialize(InetAddressAndPort endpoint, DataOutputPlus out, int version) throws IOException - { - if (version >= MessagingService.VERSION_40) - { - byte[] buf = endpoint.addressBytes; - out.writeByte(buf.length + 2); - out.write(buf); - out.writeShort(endpoint.port); - } - else - { - byte[] buf = endpoint.addressBytes; - out.writeByte(buf.length); - out.write(buf); - } - } - - public InetAddressAndPort deserialize(DataInputPlus in, int version) throws IOException - { - int size = in.readByte() & 0xFF; - switch(size) - { - //The original pre-4.0 serialiation of just an address - case 4: - case 16: - { - byte[] bytes = new byte[size]; - in.readFully(bytes, 0, bytes.length); - return InetAddressAndPort.getByAddress(bytes); - } - //Address and one port - case 6: - case 18: - { - byte[] bytes = new byte[size - 2]; - in.readFully(bytes); - - int port = in.readShort() & 0xFFFF; - return InetAddressAndPort.getByAddressOverrideDefaults(InetAddress.getByAddress(bytes), bytes, port); - } - default: - throw new AssertionError("Unexpected size " + size); - - } - } - - public long serializedSize(InetAddressAndPort from, int version) - { - //4.0 includes a port number - if (version >= MessagingService.VERSION_40) - { - if (from.address instanceof Inet4Address) - return 
1 + 4 + 2; - assert from.address instanceof Inet6Address; - return 1 + 16 + 2; - } - else - { - if (from.address instanceof Inet4Address) - return 1 + 4; - assert from.address instanceof Inet6Address; - return 1 + 16; - } - } -} diff --git a/src/java/org/apache/cassandra/net/IAsyncCallbackWithFailure.java b/src/java/org/apache/cassandra/net/ConnectionCategory.java similarity index 69% rename from src/java/org/apache/cassandra/net/IAsyncCallbackWithFailure.java rename to src/java/org/apache/cassandra/net/ConnectionCategory.java index 2b91f2056c09..d739e9d1ff00 100644 --- a/src/java/org/apache/cassandra/net/IAsyncCallbackWithFailure.java +++ b/src/java/org/apache/cassandra/net/ConnectionCategory.java @@ -15,16 +15,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.cassandra.net; -import org.apache.cassandra.exceptions.RequestFailureReason; -import org.apache.cassandra.locator.InetAddressAndPort; +package org.apache.cassandra.net; -public interface IAsyncCallbackWithFailure extends IAsyncCallback +public enum ConnectionCategory { + MESSAGING, STREAMING; + + public boolean isStreaming() + { + return this == STREAMING; + } - /** - * Called when there is an exception on the remote node or timeout happens - */ - void onFailure(InetAddressAndPort from, RequestFailureReason failureReason); + public boolean isMessaging() + { + return this == MESSAGING; + } } diff --git a/src/java/org/apache/cassandra/net/ConnectionType.java b/src/java/org/apache/cassandra/net/ConnectionType.java new file mode 100644 index 000000000000..db83d06856a2 --- /dev/null +++ b/src/java/org/apache/cassandra/net/ConnectionType.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.List; + +import com.google.common.collect.ImmutableList; + +public enum ConnectionType +{ + LEGACY_MESSAGES (0), // only used for inbound + URGENT_MESSAGES (1), + SMALL_MESSAGES (2), + LARGE_MESSAGES (3), + STREAMING (4); + + public static final List MESSAGING_TYPES = ImmutableList.of(URGENT_MESSAGES, SMALL_MESSAGES, LARGE_MESSAGES); + + public final int id; + + ConnectionType(int id) + { + this.id = id; + } + + public int twoBitID() + { + if (id < 0 || id > 0b11) + throw new AssertionError(); + return id; + } + + public boolean isStreaming() + { + return this == STREAMING; + } + + public boolean isMessaging() + { + return !isStreaming(); + } + + public ConnectionCategory category() + { + return this == STREAMING ? ConnectionCategory.STREAMING : ConnectionCategory.MESSAGING; + } + + private static final ConnectionType[] values = values(); + + public static ConnectionType fromId(int id) + { + return values[id]; + } +} diff --git a/src/java/org/apache/cassandra/net/Crc.java b/src/java/org/apache/cassandra/net/Crc.java new file mode 100644 index 000000000000..dbd26014d566 --- /dev/null +++ b/src/java/org/apache/cassandra/net/Crc.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.zip.CRC32; + +import io.netty.buffer.ByteBuf; +import io.netty.util.concurrent.FastThreadLocal; + +class Crc +{ + private static final FastThreadLocal crc32 = new FastThreadLocal() + { + @Override + protected CRC32 initialValue() + { + return new CRC32(); + } + }; + + private static final byte[] initialBytes = new byte[] { (byte) 0xFA, (byte) 0x2D, (byte) 0x55, (byte) 0xCA }; + + static final class InvalidCrc extends IOException + { + InvalidCrc(int read, int computed) + { + super(String.format("Read %d, Computed %d", read, computed)); + } + } + + static CRC32 crc32() + { + CRC32 crc = crc32.get(); + crc.reset(); + crc.update(initialBytes); + return crc; + } + + static int computeCrc32(ByteBuf buffer, int startReaderIndex, int endReaderIndex) + { + CRC32 crc = crc32(); + crc.update(buffer.internalNioBuffer(startReaderIndex, endReaderIndex - startReaderIndex)); + return (int) crc.getValue(); + } + + static int computeCrc32(ByteBuffer buffer, int start, int end) + { + CRC32 crc = crc32(); + updateCrc32(crc, buffer, start, end); + return (int) crc.getValue(); + } + + static void updateCrc32(CRC32 crc, ByteBuffer buffer, int start, int end) + { + int savePosition = 
buffer.position(); + int saveLimit = buffer.limit(); + buffer.limit(end); + buffer.position(start); + crc.update(buffer); + buffer.limit(saveLimit); + buffer.position(savePosition); + } + + private static final int CRC24_INIT = 0x875060; + /** + * Polynomial chosen from https://users.ece.cmu.edu/~koopman/crc/index.html, by Philip Koopman + * + * This webpage claims a copyright to Philip Koopman, which he licenses under the + * Creative Commons Attribution 4.0 International License (https://creativecommons.org/licenses/by/4.0) + * + * It is unclear if this copyright can extend to a 'fact' such as this specific number, particularly + * as we do not use Koopman's notation to represent the polynomial, but we anyway attribute his work and + * link the terms of his license since they are not incompatible with our usage and we greatly appreciate his work. + * + * This polynomial provides hamming distance of 8 for messages up to length 105 bits; + * we only support 8-64 bits at present, with an expected range of 40-48. + */ + private static final int CRC24_POLY = 0x1974F0B; + + /** + * NOTE: the order of bytes must reach the wire in the same order the CRC is computed, with the CRC + * immediately following in a trailer. Since we read in least significant byte order, if you + * write to a buffer using putInt or putLong, the byte order will be reversed and + * you will lose the guarantee of protection from burst corruptions of 24 bits in length. + * + * Make sure either to write byte-by-byte to the wire, or to use Integer/Long.reverseBytes if you + * write to a BIG_ENDIAN buffer. + * + * See http://users.ece.cmu.edu/~koopman/pubs/ray06_crcalgorithms.pdf + * + * Complain to the ethernet spec writers, for having inverse bit to byte significance order. + * + * Note we use the most naive algorithm here. We support at most 8 bytes, and typically supply + * 5 or fewer, so any efficiency of a table approach is swallowed by the time to hit L3, even + * for a tiny (4bit) table. 
+ * + * @param bytes an up to 8-byte register containing bytes to compute the CRC over + * the bytes AND bits will be read least-significant to most significant. + * @param len the number of bytes, greater than 0 and fewer than 9, to be read from bytes + * @return the least-significant bit AND byte order crc24 using the CRC24_POLY polynomial + */ + static int crc24(long bytes, int len) + { + int crc = CRC24_INIT; + while (len-- > 0) + { + crc ^= (bytes & 0xff) << 16; + bytes >>= 8; + + for (int i = 0; i < 8; i++) + { + crc <<= 1; + if ((crc & 0x1000000) != 0) + crc ^= CRC24_POLY; + } + } + return crc; + } +} diff --git a/src/java/org/apache/cassandra/net/EndpointMessagingVersions.java b/src/java/org/apache/cassandra/net/EndpointMessagingVersions.java new file mode 100644 index 000000000000..e8cf8f68daa5 --- /dev/null +++ b/src/java/org/apache/cassandra/net/EndpointMessagingVersions.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.net.UnknownHostException; +import java.util.concurrent.ConcurrentMap; + +import org.cliffc.high_scale_lib.NonBlockingHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.locator.InetAddressAndPort; + +/** + * Map of hosts to their known current messaging versions. + */ +public class EndpointMessagingVersions +{ + private static final Logger logger = LoggerFactory.getLogger(EndpointMessagingVersions.class); + + // protocol versions of the other nodes in the cluster + private final ConcurrentMap versions = new NonBlockingHashMap<>(); + + /** + * @return the last version associated with address, or @param version if this is the first such version + */ + public int set(InetAddressAndPort endpoint, int version) + { + logger.trace("Setting version {} for {}", version, endpoint); + + Integer v = versions.put(endpoint, version); + return v == null ? version : v; + } + + public void reset(InetAddressAndPort endpoint) + { + logger.trace("Resetting version for {}", endpoint); + versions.remove(endpoint); + } + + /** + * Returns the messaging-version as announced by the given node but capped + * to the min of the version as announced by the node and {@link MessagingService#current_version}. + */ + public int get(InetAddressAndPort endpoint) + { + Integer v = versions.get(endpoint); + if (v == null) + { + // we don't know the version. assume current. we'll know soon enough if that was incorrect. + logger.trace("Assuming current protocol version for {}", endpoint); + return MessagingService.current_version; + } + else + return Math.min(v, MessagingService.current_version); + } + + public int get(String endpoint) throws UnknownHostException + { + return get(InetAddressAndPort.getByName(endpoint)); + } + + /** + * Returns the messaging-version exactly as announced by the given endpoint. 
+ */ + public int getRaw(InetAddressAndPort endpoint) + { + Integer v = versions.get(endpoint); + if (v == null) + throw new IllegalStateException("getRawVersion() was called without checking knowsVersion() result first"); + return v; + } + + public boolean knows(InetAddressAndPort endpoint) + { + return versions.containsKey(endpoint); + } +} diff --git a/src/java/org/apache/cassandra/net/ForwardToSerializer.java b/src/java/org/apache/cassandra/net/ForwardToSerializer.java deleted file mode 100644 index c4e8843ec2ea..000000000000 --- a/src/java/org/apache/cassandra/net/ForwardToSerializer.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; - -public class ForwardToSerializer implements IVersionedSerializer -{ - public static ForwardToSerializer instance = new ForwardToSerializer(); - - private ForwardToSerializer() {} - - public void serialize(ForwardToContainer forwardToContainer, DataOutputPlus out, int version) throws IOException - { - out.writeInt(forwardToContainer.targets.size()); - Iterator iter = forwardToContainer.targets.iterator(); - for (int ii = 0; ii < forwardToContainer.messageIds.length; ii++) - { - CompactEndpointSerializationHelper.instance.serialize(iter.next(), out, version); - out.writeInt(forwardToContainer.messageIds[ii]); - } - } - - public ForwardToContainer deserialize(DataInputPlus in, int version) throws IOException - { - int[] ids = new int[in.readInt()]; - List hosts = new ArrayList<>(ids.length); - for (int ii = 0; ii < ids.length; ii++) - { - hosts.add(CompactEndpointSerializationHelper.instance.deserialize(in, version)); - ids[ii] = in.readInt(); - } - return new ForwardToContainer(hosts, ids); - } - - public long serializedSize(ForwardToContainer forwardToContainer, int version) - { - //Number of forward addresses, 4 bytes per for each id - long size = 4 + - (4 * forwardToContainer.targets.size()); - //Depending on ipv6 or ipv4 the address size is different. 
- for (InetAddressAndPort forwardTo : forwardToContainer.targets) - { - size += CompactEndpointSerializationHelper.instance.serializedSize(forwardTo, version); - } - - return size; - } - - public static ForwardToContainer fromBytes(byte[] bytes, int version) - { - try (DataInputBuffer input = new DataInputBuffer(bytes)) - { - return instance.deserialize(input, version); - } - catch (IOException e) - { - throw new RuntimeException(e); - } - } -} diff --git a/src/java/org/apache/cassandra/net/ForwardingInfo.java b/src/java/org/apache/cassandra/net/ForwardingInfo.java new file mode 100644 index 000000000000..737da48f40d3 --- /dev/null +++ b/src/java/org/apache/cassandra/net/ForwardingInfo.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; + +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; + +import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.utils.vint.VIntCoding.computeUnsignedVIntSize; + +/** + * A container used to store a node -> message_id map for inter-DC write forwarding. + * We pick one node in each external DC to forward the message to its local peers. + * + * TODO: in the next protocol version only serialize peers, message id will become redundant once 3.0 is out of the picture + */ +public final class ForwardingInfo implements Serializable +{ + final List targets; + final long[] messageIds; + + public ForwardingInfo(List targets, long[] messageIds) + { + Preconditions.checkArgument(targets.size() == messageIds.length); + this.targets = targets; + this.messageIds = messageIds; + } + + /** + * @return {@code true} if all host are to use the same message id, {@code false} otherwise. Starting with 4.0 and + * above, we should be reusing the same id, always, but it won't always be true until 3.0/3.11 are phased out. + */ + public boolean useSameMessageID() + { + if (messageIds.length < 2) + return true; + + long id = messageIds[0]; + for (int i = 1; i < messageIds.length; i++) + if (id != messageIds[i]) + return false; + + return true; + } + + /** + * Apply the provided consumer to all (host, message_id) pairs. 
+ */ + public void forEach(BiConsumer biConsumer) + { + for (int i = 0; i < messageIds.length; i++) + biConsumer.accept(messageIds[i], targets.get(i)); + } + + static final IVersionedSerializer serializer = new IVersionedSerializer() + { + public void serialize(ForwardingInfo forwardTo, DataOutputPlus out, int version) throws IOException + { + long[] ids = forwardTo.messageIds; + List targets = forwardTo.targets; + + int count = ids.length; + if (version >= VERSION_40) + out.writeUnsignedVInt(count); + else + out.writeInt(count); + + for (int i = 0; i < count; i++) + { + inetAddressAndPortSerializer.serialize(targets.get(i), out, version); + if (version >= VERSION_40) + out.writeUnsignedVInt(ids[i]); + else + out.writeInt(Ints.checkedCast(ids[i])); + } + } + + public long serializedSize(ForwardingInfo forwardTo, int version) + { + long[] ids = forwardTo.messageIds; + List targets = forwardTo.targets; + + int count = ids.length; + long size = version >= VERSION_40 ? computeUnsignedVIntSize(count) : TypeSizes.sizeof(count); + + for (int i = 0; i < count; i++) + { + size += inetAddressAndPortSerializer.serializedSize(targets.get(i), version); + size += version >= VERSION_40 ? computeUnsignedVIntSize(ids[i]) : 4; + } + + return size; + } + + public ForwardingInfo deserialize(DataInputPlus in, int version) throws IOException + { + int count = version >= VERSION_40 ? Ints.checkedCast(in.readUnsignedVInt()) : in.readInt(); + + long[] ids = new long[count]; + List targets = new ArrayList<>(count); + + for (int i = 0; i < count; i++) + { + targets.add(inetAddressAndPortSerializer.deserialize(in, version)); + ids[i] = version >= VERSION_40 ? 
Ints.checkedCast(in.readUnsignedVInt()) : in.readInt(); + } + + return new ForwardingInfo(targets, ids); + } + }; +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoder.java b/src/java/org/apache/cassandra/net/FrameDecoder.java new file mode 100644 index 000000000000..ed96adda32b8 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoder.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Deque; + +import com.google.common.annotations.VisibleForTesting; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; + +import static org.apache.cassandra.utils.ByteBufferUtil.copyBytes; + +/** + * A Netty inbound handler that decodes incoming frames and passes them forward to + * {@link InboundMessageHandler} for processing. + * + * Handles work stashing, and together with {@link InboundMessageHandler} - flow control. + * + * Unlike most Netty inbound handlers, doesn't use the pipeline to talk to its + * upstream handler. 
Instead, a {@link FrameProcessor} must be registered with + * the frame decoder, to be invoked on new frames. See {@link #deliver(FrameProcessor)}. + * + * See {@link #activate(FrameProcessor)}, {@link #reactivate()}, and {@link FrameProcessor} + * for flow control implementation. + * + * Five frame decoders currently exist, one used for each connection depending on flags and messaging version: + * 1. {@link FrameDecoderCrc}: + no compression; payload is protected by CRC32 + * 2. {@link FrameDecoderLZ4}: + LZ4 compression with custom frame format; payload is protected by CRC32 + * 3. {@link FrameDecoderUnprotected}: + no compression; no integrity protection + * 4. {@link FrameDecoderLegacy}: + no compression; no integrity protection; turns unframed streams of legacy messages (< 4.0) into frames + * 5. {@link FrameDecoderLegacyLZ4} + * LZ4 compression using standard LZ4 frame format; groups legacy messages (< 4.0) into frames + */ +abstract class FrameDecoder extends ChannelInboundHandlerAdapter +{ + private static final FrameProcessor NO_PROCESSOR = + frame -> { throw new IllegalStateException("Frame processor invoked on an unregistered FrameDecoder"); }; + + private static final FrameProcessor CLOSED_PROCESSOR = + frame -> { throw new IllegalStateException("Frame processor invoked on a closed FrameDecoder"); }; + + interface FrameProcessor + { + /** + * Frame processor that the frames should be handed off to. + * + * @return true if more frames can be taken by the processor, false if the decoder should pause until + * it's explicitly resumed. + */ + boolean process(Frame frame) throws IOException; + } + + abstract static class Frame + { + final boolean isSelfContained; + final int frameSize; + + Frame(boolean isSelfContained, int frameSize) + { + this.isSelfContained = isSelfContained; + this.frameSize = frameSize; + } + + abstract void release(); + abstract boolean isConsumed(); + } + + /** + * The payload bytes of a complete frame, i.e. 
a frame stripped of its headers and trailers, + * with any verification supported by the protocol confirmed. + * + * If {@code isSelfContained} the payload contains one or more {@link Message}, all of which + * may be parsed entirely from the bytes provided. Otherwise, only a part of exactly one + * {@link Message} is contained in the payload; it can be relied upon that this partial {@link Message} + * will only be delivered in its own unique {@link Frame}. + */ + final static class IntactFrame extends Frame + { + final ShareableBytes contents; + + IntactFrame(boolean isSelfContained, ShareableBytes contents) + { + super(isSelfContained, contents.remaining()); + this.contents = contents; + } + + void release() + { + contents.release(); + } + + boolean isConsumed() + { + return !contents.hasRemaining(); + } + + void consume() + { + contents.consume(); + } + } + + /** + * A corrupted frame was encountered; this represents the knowledge we have about this frame, + * and whether or not the stream is recoverable. + * + * Generally we consider a frame with corrupted header as unrecoverable, and frames with intact header, + * but corrupted payload - as recoverable, since we know and can skip payload size. + * + * {@link InboundMessageHandler} further has its own idea of which frames are and aren't recoverable. + * A recoverable {@link CorruptFrame} can be considered unrecoverable by {@link InboundMessageHandler} + * if it's the first frame of a large message (isn't self contained). 
+ */ + final static class CorruptFrame extends Frame + { + final int readCRC, computedCRC; + + CorruptFrame(boolean isSelfContained, int frameSize, int readCRC, int computedCRC) + { + super(isSelfContained, frameSize); + this.readCRC = readCRC; + this.computedCRC = computedCRC; + } + + static CorruptFrame recoverable(boolean isSelfContained, int frameSize, int readCRC, int computedCRC) + { + return new CorruptFrame(isSelfContained, frameSize, readCRC, computedCRC); + } + + static CorruptFrame unrecoverable(int readCRC, int computedCRC) + { + return new CorruptFrame(false, Integer.MIN_VALUE, readCRC, computedCRC); + } + + boolean isRecoverable() + { + return frameSize != Integer.MIN_VALUE; + } + + void release() { } + + boolean isConsumed() + { + return true; + } + } + + protected final BufferPoolAllocator allocator; + + @VisibleForTesting + final Deque frames = new ArrayDeque<>(4); + ByteBuffer stash; + + private boolean isActive; + private boolean isClosed; + private ChannelHandlerContext ctx; + private FrameProcessor processor = NO_PROCESSOR; + + FrameDecoder(BufferPoolAllocator allocator) + { + this.allocator = allocator; + } + + abstract void decode(Collection into, ShareableBytes bytes); + abstract void addLastTo(ChannelPipeline pipeline); + + /** + * For use by InboundMessageHandler (or other upstream handlers) that want to start receiving frames. + */ + void activate(FrameProcessor processor) + { + if (this.processor != NO_PROCESSOR) + throw new IllegalStateException("Attempted to activate an already active FrameDecoder"); + + this.processor = processor; + + isActive = true; + ctx.read(); + } + + /** + * For use by InboundMessageHandler (or other upstream handlers) that want to resume + * receiving frames after previously indicating that processing should be paused. 
+ */ + void reactivate() throws IOException + { + if (isActive) + throw new IllegalStateException("Tried to reactivate an already active FrameDecoder"); + + if (deliver(processor)) + { + isActive = true; + onExhausted(); + } + } + + /** + * For use by InboundMessageHandler (or other upstream handlers) that want to resume + * receiving frames after previously indicating that processing should be paused. + * + * Does not reactivate processing or reading from the wire, but permits processing as many frames (or parts thereof) + * that are already waiting as the processor requires. + */ + void processBacklog(FrameProcessor processor) throws IOException + { + deliver(processor); + } + + /** + * For use by InboundMessageHandler (or other upstream handlers) that want to permanently + * stop receiving frames, e.g. because of an exception caught. + */ + void discard() + { + isActive = false; + processor = CLOSED_PROCESSOR; + if (stash != null) + { + ByteBuffer bytes = stash; + stash = null; + allocator.put(bytes); + } + while (!frames.isEmpty()) + frames.poll().release(); + } + + /** + * Called by Netty pipeline when a new message arrives; we anticipate in normal operation + * this will receive messages of type {@link BufferPoolAllocator.Wrapped} or + * {@link ShareableBytes}. 
+ * + * These buffers are unwrapped and passed to {@link #decode(Collection, ShareableBytes)}, + * which collects decoded frames into {@link #frames}, which we send upstream in {@link #deliver} + */ + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException + { + if (msg instanceof BufferPoolAllocator.Wrapped) + { + ByteBuffer buf = ((BufferPoolAllocator.Wrapped) msg).adopt(); + // netty will probably have mis-predicted the space needed + allocator.putUnusedPortion(buf); + channelRead(ShareableBytes.wrap(buf)); + } + else if (msg instanceof ShareableBytes) // legacy LZ4 decoder + { + channelRead((ShareableBytes) msg); + } + else + { + throw new IllegalArgumentException(); + } + } + + void channelRead(ShareableBytes bytes) throws IOException + { + decode(frames, bytes); + + if (isActive) isActive = deliver(processor); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) + { + if (isActive) + onExhausted(); + } + + /** + * Only to be invoked when frames.isEmpty(). + * + * If we have been closed, we will now propagate up the channelInactive notification, + * and otherwise we will ask the channel for more data. + */ + private void onExhausted() + { + if (isClosed) + close(); + else + ctx.read(); + } + + /** + * Deliver any waiting frames, including those that were incompletely read last time, to the provided processor + * until the processor returns {@code false}, or we finish the backlog. + * + * Propagate the final return value of the processor. 
+ */ + private boolean deliver(FrameProcessor processor) throws IOException + { + boolean deliver = true; + while (deliver && !frames.isEmpty()) + { + Frame frame = frames.peek(); + deliver = processor.process(frame); + + assert !deliver || frame.isConsumed(); + if (deliver || frame.isConsumed()) + { + frames.poll(); + frame.release(); + } + } + return deliver; + } + + void stash(ShareableBytes in, int stashLength, int begin, int length) + { + ByteBuffer out = allocator.getAtLeast(stashLength); + copyBytes(in.get(), begin, out, 0, length); + out.position(length); + stash = out; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) + { + this.ctx = ctx; + ctx.channel().config().setAutoRead(false); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + isClosed = true; + if (frames.isEmpty()) + close(); + } + + private void close() + { + discard(); + ctx.fireChannelInactive(); + allocator.release(); + } + + /** + * Utility: fill {@code out} from {@code in} up to {@code toOutPosition}, + * updating the position of both buffers with the result + * @return true if there were sufficient bytes to fill to {@code toOutPosition} + */ + static boolean copyToSize(ByteBuffer in, ByteBuffer out, int toOutPosition) + { + int bytesToSize = toOutPosition - out.position(); + if (bytesToSize <= 0) + return true; + + if (bytesToSize > in.remaining()) + { + out.put(in); + return false; + } + + copyBytes(in, in.position(), out, out.position(), bytesToSize); + in.position(in.position() + bytesToSize); + out.position(toOutPosition); + return true; + } + + /** + * @return {@code in} if has sufficient capacity, otherwise + * a replacement from {@code BufferPool} that {@code in} is copied into + */ + ByteBuffer ensureCapacity(ByteBuffer in, int capacity) + { + if (in.capacity() >= capacity) + return in; + + ByteBuffer out = allocator.getAtLeast(capacity); + in.flip(); + out.put(in); + allocator.put(in); + return out; + } +} diff --git 
a/src/java/org/apache/cassandra/net/FrameDecoderCrc.java b/src/java/org/apache/cassandra/net/FrameDecoderCrc.java new file mode 100644 index 000000000000..7cd52ac7df63 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderCrc.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Collection; +import java.util.zip.CRC32; + +import io.netty.channel.ChannelPipeline; + +import static org.apache.cassandra.net.Crc.*; +import static org.apache.cassandra.net.Crc.updateCrc32; + +/** + * Framing format that protects integrity of data in movement with CRCs (of both header and payload). + * + * Every on-wire frame contains: + * 1. Payload length (17 bits) + * 2. {@code isSelfContained} flag (1 bit) + * 3. Header padding (6 bits) + * 4. CRC24 of the header (24 bits) + * 5. Payload (up to 2 ^ 17 - 1 bytes) + * 6. 
Payload CRC32 (32 bits) + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Payload Length |C| | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * CRC24 of Header | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + * | | + * + + + * | Payload | + * + + + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | CRC32 of Payload | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +final class FrameDecoderCrc extends FrameDecoderWith8bHeader +{ + private FrameDecoderCrc(BufferPoolAllocator allocator) + { + super(allocator); + } + + public static FrameDecoderCrc create(BufferPoolAllocator allocator) + { + return new FrameDecoderCrc(allocator); + } + + static final int HEADER_LENGTH = 6; + private static final int TRAILER_LENGTH = 4; + private static final int HEADER_AND_TRAILER_LENGTH = 10; + + static boolean isSelfContained(long header6b) + { + return 0 != (header6b & (1L << 17)); + } + + static int payloadLength(long header6b) + { + return ((int) header6b) & 0x1FFFF; + } + + private static int headerCrc(long header6b) + { + return ((int) (header6b >>> 24)) & 0xFFFFFF; + } + + static long readHeader6b(ByteBuffer frame, int begin) + { + long header6b; + if (frame.limit() - begin >= 8) + { + header6b = frame.getLong(begin); + if (frame.order() == ByteOrder.BIG_ENDIAN) + header6b = Long.reverseBytes(header6b); + header6b &= 0xffffffffffffL; + } + else + { + header6b = 0; + for (int i = 0 ; i < HEADER_LENGTH ; ++i) + header6b |= (0xffL & frame.get(begin + i)) << (8 * i); + } + return header6b; + } + + static CorruptFrame verifyHeader6b(long header6b) + { + int computeLengthCrc = crc24(header6b, 3); + int readLengthCrc = headerCrc(header6b); + + return readLengthCrc == computeLengthCrc ? 
null : CorruptFrame.unrecoverable(readLengthCrc, computeLengthCrc); + } + + final long readHeader(ByteBuffer frame, int begin) + { + return readHeader6b(frame, begin); + } + + final CorruptFrame verifyHeader(long header6b) + { + return verifyHeader6b(header6b); + } + + final int frameLength(long header6b) + { + return payloadLength(header6b) + HEADER_AND_TRAILER_LENGTH; + } + + final Frame unpackFrame(ShareableBytes bytes, int begin, int end, long header6b) + { + ByteBuffer in = bytes.get(); + boolean isSelfContained = isSelfContained(header6b); + + CRC32 crc = crc32(); + int readFullCrc = in.getInt(end - TRAILER_LENGTH); + if (in.order() == ByteOrder.BIG_ENDIAN) + readFullCrc = Integer.reverseBytes(readFullCrc); + + updateCrc32(crc, in, begin + HEADER_LENGTH, end - TRAILER_LENGTH); + int computeFullCrc = (int) crc.getValue(); + + if (readFullCrc != computeFullCrc) + return CorruptFrame.recoverable(isSelfContained, (end - begin) - HEADER_AND_TRAILER_LENGTH, readFullCrc, computeFullCrc); + + return new IntactFrame(isSelfContained, bytes.slice(begin + HEADER_LENGTH, end - TRAILER_LENGTH)); + } + + void decode(Collection into, ShareableBytes bytes) + { + decode(into, bytes, HEADER_LENGTH); + } + + void addLastTo(ChannelPipeline pipeline) + { + pipeline.addLast("frameDecoderCrc", this); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoderLZ4.java b/src/java/org/apache/cassandra/net/FrameDecoderLZ4.java new file mode 100644 index 000000000000..941139a0155f --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderLZ4.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Collection; +import java.util.zip.CRC32; + +import io.netty.channel.ChannelPipeline; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; + +import static org.apache.cassandra.net.Crc.*; + +/** + * Framing format that compresses payloads with LZ4, and protects integrity of data in movement with CRCs + * (of both header and payload). + * + * Every on-wire frame contains: + * 1. Compressed length (17 bits) + * 2. Uncompressed length (17 bits) + * 3. {@code isSelfContained} flag (1 bit) + * 4. Header padding (5 bits) + * 5. CRC24 of Header contents (24 bits) + * 6. Compressed Payload (up to 2 ^ 17 - 1 bytes) + * 7. 
CRC32 of Compressed Payload (32 bits) + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Compressed Length | Uncompressed Length + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |C| | CRC24 of Header | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | | + * + + + * | Compressed Payload | + * + + + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | CRC32 of Compressed Payload | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +final class FrameDecoderLZ4 extends FrameDecoderWith8bHeader +{ + public static FrameDecoderLZ4 fast(BufferPoolAllocator allocator) + { + return new FrameDecoderLZ4(allocator, LZ4Factory.fastestInstance().fastDecompressor()); + } + + private static final int HEADER_LENGTH = 8; + private static final int TRAILER_LENGTH = 4; + private static final int HEADER_AND_TRAILER_LENGTH = 12; + + private static int compressedLength(long header8b) + { + return ((int) header8b) & 0x1FFFF; + } + private static int uncompressedLength(long header8b) + { + return ((int) (header8b >>> 17)) & 0x1FFFF; + } + private static boolean isSelfContained(long header8b) + { + return 0 != (header8b & (1L << 34)); + } + private static int headerCrc(long header8b) + { + return ((int) (header8b >>> 40)) & 0xFFFFFF; + } + + private final LZ4FastDecompressor decompressor; + + private FrameDecoderLZ4(BufferPoolAllocator allocator, LZ4FastDecompressor decompressor) + { + super(allocator); + this.decompressor = decompressor; + } + + final long readHeader(ByteBuffer frame, int begin) + { + long header8b = frame.getLong(begin); + if (frame.order() == ByteOrder.BIG_ENDIAN) + header8b = Long.reverseBytes(header8b); + return header8b; + } + + final CorruptFrame verifyHeader(long header8b) + { + int computeLengthCrc = crc24(header8b, 5); + int readLengthCrc = 
headerCrc(header8b); + + return readLengthCrc == computeLengthCrc ? null : CorruptFrame.unrecoverable(readLengthCrc, computeLengthCrc); + } + + final int frameLength(long header8b) + { + return compressedLength(header8b) + HEADER_AND_TRAILER_LENGTH; + } + + final Frame unpackFrame(ShareableBytes bytes, int begin, int end, long header8b) + { + ByteBuffer input = bytes.get(); + + boolean isSelfContained = isSelfContained(header8b); + int uncompressedLength = uncompressedLength(header8b); + + CRC32 crc = crc32(); + int readFullCrc = input.getInt(end - TRAILER_LENGTH); + if (input.order() == ByteOrder.BIG_ENDIAN) + readFullCrc = Integer.reverseBytes(readFullCrc); + + updateCrc32(crc, input, begin + HEADER_LENGTH, end - TRAILER_LENGTH); + int computeFullCrc = (int) crc.getValue(); + + if (readFullCrc != computeFullCrc) + return CorruptFrame.recoverable(isSelfContained, uncompressedLength, readFullCrc, computeFullCrc); + + if (uncompressedLength == 0) + { + return new IntactFrame(isSelfContained, bytes.slice(begin + HEADER_LENGTH, end - TRAILER_LENGTH)); + } + else + { + ByteBuffer out = allocator.get(uncompressedLength); + try + { + decompressor.decompress(input, begin + HEADER_LENGTH, out, 0, uncompressedLength); + return new IntactFrame(isSelfContained, ShareableBytes.wrap(out)); + } + catch (Throwable t) + { + allocator.put(out); + throw t; + } + } + } + + void decode(Collection into, ShareableBytes bytes) + { + // TODO: confirm in assembly output that we inline the relevant nested method calls + decode(into, bytes, HEADER_LENGTH); + } + + void addLastTo(ChannelPipeline pipeline) + { + pipeline.addLast("frameDecoderLZ4", this); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoderLegacy.java b/src/java/org/apache/cassandra/net/FrameDecoderLegacy.java new file mode 100644 index 000000000000..a3d7bc593ea4 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderLegacy.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.util.Collection; + +import io.netty.channel.ChannelPipeline; + +import static java.lang.Math.max; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; + +/** + * {@link InboundMessageHandler} operates on frames that adhere to a certain contract + * (see {@link FrameDecoder.IntactFrame} and {@link FrameDecoder.CorruptFrame} javadoc). + * + * Legacy (pre-4.0) messaging protocol does not natively support framing, however. The job + * of {@link FrameDecoderLegacy} is turn a raw stream of messages, serialized back to back, + * into a sequence of frames that adhere to 4.0+ conventions. 
+ */ +class FrameDecoderLegacy extends FrameDecoder +{ + private final int messagingVersion; + + private int remainingBytesInLargeMessage = 0; + + FrameDecoderLegacy(BufferPoolAllocator allocator, int messagingVersion) + { + super(allocator); + this.messagingVersion = messagingVersion; + } + + final void decode(Collection into, ShareableBytes newBytes) + { + ByteBuffer in = newBytes.get(); + try + { + if (stash != null) + { + int length = Message.serializer.inferMessageSize(stash, 0, stash.position(), messagingVersion); + while (length < 0) + { + if (!in.hasRemaining()) + return; + + if (stash.position() == stash.capacity()) + stash = ensureCapacity(stash, stash.capacity() * 2); + copyToSize(in, stash, stash.capacity()); + + length = Message.serializer.inferMessageSize(stash, 0, stash.position(), messagingVersion); + if (length >= 0 && length < stash.position()) + { + int excess = stash.position() - length; + in.position(in.position() - excess); + stash.position(length); + } + } + + final boolean isSelfContained; + if (length <= LARGE_MESSAGE_THRESHOLD) + { + isSelfContained = true; + + if (length > stash.capacity()) + stash = ensureCapacity(stash, length); + + stash.limit(length); + allocator.putUnusedPortion(stash); // we may be over capacity from earlier doubling + if (!copyToSize(in, stash, length)) + return; + } + else + { + isSelfContained = false; + remainingBytesInLargeMessage = length - stash.position(); + + stash.limit(stash.position()); + allocator.putUnusedPortion(stash); + } + + stash.flip(); + assert !isSelfContained || stash.limit() == length; + ShareableBytes stashed = ShareableBytes.wrap(stash); + into.add(new IntactFrame(isSelfContained, stashed)); + stash = null; + } + + if (remainingBytesInLargeMessage > 0) + { + if (remainingBytesInLargeMessage >= newBytes.remaining()) + { + remainingBytesInLargeMessage -= newBytes.remaining(); + into.add(new IntactFrame(false, newBytes.sliceAndConsume(newBytes.remaining()))); + return; + } + else + { + Frame 
frame = new IntactFrame(false, newBytes.sliceAndConsume(remainingBytesInLargeMessage)); + remainingBytesInLargeMessage = 0; + into.add(frame); + } + } + + // we loop incrementing our end pointer until we have no more complete messages, + // at which point we slice the complete messages, and stash the remainder + int begin = in.position(); + int end = begin; + int limit = in.limit(); + + if (begin == limit) + return; + + while (true) + { + int length = Message.serializer.inferMessageSize(in, end, limit, messagingVersion); + + if (length >= 0) + { + if (end + length <= limit) + { + // we have a complete message, so just bump our end pointer + end += length; + + // if we have more bytes, continue to look for another message + if (end < limit) + continue; + + // otherwise reset length, as we have accounted for it in end + length = 0; + } + } + + // we are done; if we have found any complete messages, slice them all into a single frame + if (begin < end) + into.add(new IntactFrame(true, newBytes.slice(begin, end))); + + // now consider stashing anything leftover + if (length < 0) + { + stash(newBytes, max(64, limit - end), end, limit - end); + } + else if (length > LARGE_MESSAGE_THRESHOLD) + { + remainingBytesInLargeMessage = length - (limit - end); + Frame frame = new IntactFrame(false, newBytes.slice(end, limit)); + into.add(frame); + } + else if (length > 0) + { + stash(newBytes, length, end, limit - end); + } + break; + } + } + catch (Message.InvalidLegacyProtocolMagic e) + { + into.add(CorruptFrame.unrecoverable(e.read, Message.PROTOCOL_MAGIC)); + } + finally + { + newBytes.release(); + } + } + + void addLastTo(ChannelPipeline pipeline) + { + pipeline.addLast("frameDecoderNone", this); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoderLegacyLZ4.java b/src/java/org/apache/cassandra/net/FrameDecoderLegacyLZ4.java new file mode 100644 index 000000000000..f2556a5c880e --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderLegacyLZ4.java @@ 
-0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Deque; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.compression.Lz4FrameDecoder; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.lang.Integer.reverseBytes; +import static java.lang.String.format; +import static org.apache.cassandra.net.LegacyLZ4Constants.*; +import static org.apache.cassandra.utils.ByteBufferUtil.copyBytes; + +/** + * A {@link FrameDecoder} consisting of two chained handlers: + * 1. A legacy LZ4 block decoder, described below in the description of {@link LZ4Decoder}, followed by + * 2. 
An instance of {@link FrameDecoderLegacy} - transforming the raw messages in the uncompressed stream + * into properly formed frames expected by {@link InboundMessageHandler} + */ +class FrameDecoderLegacyLZ4 extends FrameDecoderLegacy +{ + FrameDecoderLegacyLZ4(BufferPoolAllocator allocator, int messagingVersion) + { + super(allocator, messagingVersion); + } + + @Override + void addLastTo(ChannelPipeline pipeline) + { + pipeline.addLast("legacyLZ4Decoder", new LZ4Decoder(allocator)); + pipeline.addLast("frameDecoderNone", this); + } + + /** + * An implementation of LZ4 decoder, used for legacy (3.0, 3.11) connections. + * + * Netty's provided implementation - {@link Lz4FrameDecoder} couldn't be reused for + * two reasons: + * 1. It has very poor performance when coupled with xxHash, which we use for legacy connections - + * allocating a single-byte array and making a JNI call for every byte of the payload + * 2. It was tricky to efficiently integrate with upstream {@link FrameDecoder}, and impossible + * to make it play nicely with flow control - Netty's implementation, based on + * {@link io.netty.handler.codec.ByteToMessageDecoder}, would potentially keep triggering + * reads on its own volition for as long as its last read had no completed frames to supply + * - defying our goal to only ever trigger channel reads when explicitly requested + * + * Since the original LZ4 block format does not contains size of compressed block and size of original data + * this encoder uses format like LZ4 Java library + * written by Adrien Grand and approved by Yann Collet (author of original LZ4 library), as implemented by + * Netty's {@link Lz4FrameDecoder}, but adapted for our interaction model. 
+ * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | | + * + Magic + + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |T| Compressed Length + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Uncompressed Length + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | xxHash32 of Uncompressed Payload + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | | + * +-+ + + * | | + * + Payload + + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + private static class LZ4Decoder extends ChannelInboundHandlerAdapter + { + private static final XXHash32 xxhash = + XXHashFactory.fastestInstance().hash32(); + + private static final LZ4FastDecompressor decompressor = + LZ4Factory.fastestInstance().fastDecompressor(); + + private final BufferPoolAllocator allocator; + + LZ4Decoder(BufferPoolAllocator allocator) + { + this.allocator = allocator; + } + + private final Deque frames = new ArrayDeque<>(4); + + // total # of frames decoded between two subsequent invocations of channelReadComplete() + private int decodedFrameCount = 0; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws CorruptLZ4Frame + { + assert msg instanceof BufferPoolAllocator.Wrapped; + ByteBuffer buf = ((BufferPoolAllocator.Wrapped) msg).adopt(); + // netty will probably have mis-predicted the space needed + BufferPool.putUnusedPortion(buf); + + CorruptLZ4Frame error = null; + try + { + decode(frames, ShareableBytes.wrap(buf)); + } + catch (CorruptLZ4Frame e) + { + error = e; + } + finally + { + decodedFrameCount += frames.size(); + while (!frames.isEmpty()) + ctx.fireChannelRead(frames.poll()); + } + + if (null != error) + throw error; + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) + { + /* + * If no frames have 
been decoded from the entire batch of channelRead() calls, + * then we must trigger another channel read explicitly, or else risk stalling + * forever without bytes to complete the current in-flight frame. + */ + if (null != stash && decodedFrameCount == 0 && !ctx.channel().config().isAutoRead()) + ctx.read(); + + decodedFrameCount = 0; + ctx.fireChannelReadComplete(); + } + + private void decode(Collection into, ShareableBytes newBytes) throws CorruptLZ4Frame + { + try + { + doDecode(into, newBytes); + } + finally + { + newBytes.release(); + } + } + + private void doDecode(Collection into, ShareableBytes newBytes) throws CorruptLZ4Frame + { + ByteBuffer in = newBytes.get(); + + if (null != stash) + { + if (!copyToSize(in, stash, HEADER_LENGTH)) + return; + + header.read(stash, 0); + header.validate(); + + int frameLength = header.frameLength(); + stash = ensureCapacity(stash, frameLength); + + if (!copyToSize(in, stash, frameLength)) + return; + + stash.flip(); + ShareableBytes stashed = ShareableBytes.wrap(stash); + stash = null; + + try + { + into.add(decompressFrame(stashed, 0, frameLength, header)); + } + finally + { + stashed.release(); + } + } + + int begin = in.position(); + int limit = in.limit(); + while (begin < limit) + { + int remaining = limit - begin; + if (remaining < HEADER_LENGTH) + { + stash(newBytes, HEADER_LENGTH, begin, remaining); + return; + } + + header.read(in, begin); + header.validate(); + + int frameLength = header.frameLength(); + if (remaining < frameLength) + { + stash(newBytes, frameLength, begin, remaining); + return; + } + + into.add(decompressFrame(newBytes, begin, begin + frameLength, header)); + begin += frameLength; + } + } + + private ShareableBytes decompressFrame(ShareableBytes bytes, int begin, int end, Header header) throws CorruptLZ4Frame + { + ByteBuffer buf = bytes.get(); + + if (header.uncompressedLength == 0) + return bytes.slice(begin + HEADER_LENGTH, end); + + if (!header.isCompressed()) + { + validateChecksum(buf, 
begin + HEADER_LENGTH, header); + return bytes.slice(begin + HEADER_LENGTH, end); + } + + ByteBuffer out = allocator.get(header.uncompressedLength); + try + { + decompressor.decompress(buf, begin + HEADER_LENGTH, out, 0, header.uncompressedLength); + validateChecksum(out, 0, header); + return ShareableBytes.wrap(out); + } + catch (Throwable t) + { + BufferPool.put(out); + throw t; + } + } + + private void validateChecksum(ByteBuffer buf, int begin, Header header) throws CorruptLZ4Frame + { + int checksum = xxhash.hash(buf, begin, header.uncompressedLength, XXHASH_SEED) & XXHASH_MASK; + if (checksum != header.checksum) + except("Invalid checksum detected: %d (expected: %d)", checksum, header.checksum); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + if (null != stash) + { + BufferPool.put(stash); + stash = null; + } + + while (!frames.isEmpty()) + frames.poll().release(); + + ctx.fireChannelInactive(); + } + + /* reusable container for deserialized header fields */ + private static final class Header + { + long magicNumber; + byte token; + int compressedLength; + int uncompressedLength; + int checksum; + + int frameLength() + { + return HEADER_LENGTH + compressedLength; + } + + boolean isCompressed() + { + return (token & 0xF0) == 0x20; + } + + int maxUncompressedLength() + { + return 1 << ((token & 0x0F) + 10); + } + + void read(ByteBuffer in, int begin) + { + magicNumber = in.getLong(begin + MAGIC_NUMBER_OFFSET ); + token = in.get (begin + TOKEN_OFFSET ); + compressedLength = reverseBytes(in.getInt (begin + COMPRESSED_LENGTH_OFFSET )); + uncompressedLength = reverseBytes(in.getInt (begin + UNCOMPRESSED_LENGTH_OFFSET)); + checksum = reverseBytes(in.getInt (begin + CHECKSUM_OFFSET )); + } + + void validate() throws CorruptLZ4Frame + { + if (magicNumber != MAGIC_NUMBER) + except("Invalid magic number at the beginning of an LZ4 block: %d", magicNumber); + + int blockType = token & 0xF0; + if (!(blockType == BLOCK_TYPE_COMPRESSED || 
blockType == BLOCK_TYPE_NON_COMPRESSED)) + except("Invalid block type encountered: %d", blockType); + + if (compressedLength < 0 || compressedLength > MAX_BLOCK_LENGTH) + except("Invalid compressedLength: %d (expected: 0-%d)", compressedLength, MAX_BLOCK_LENGTH); + + if (uncompressedLength < 0 || uncompressedLength > maxUncompressedLength()) + except("Invalid uncompressedLength: %d (expected: 0-%d)", uncompressedLength, maxUncompressedLength()); + + if ( uncompressedLength == 0 && compressedLength != 0 + || uncompressedLength != 0 && compressedLength == 0 + || !isCompressed() && uncompressedLength != compressedLength) + { + except("Stream corrupted: compressedLength(%d) and decompressedLength(%d) mismatch", compressedLength, uncompressedLength); + } + } + } + private final Header header = new Header(); + + /** + * @return {@code in} if has sufficient capacity, otherwise a replacement from {@code BufferPool} that {@code in} is copied into + */ + private ByteBuffer ensureCapacity(ByteBuffer in, int capacity) + { + if (in.capacity() >= capacity) + return in; + + ByteBuffer out = allocator.getAtLeast(capacity); + in.flip(); + out.put(in); + BufferPool.put(in); + return out; + } + + private ByteBuffer stash; + + private void stash(ShareableBytes in, int stashLength, int begin, int length) + { + ByteBuffer out = allocator.getAtLeast(stashLength); + copyBytes(in.get(), begin, out, 0, length); + out.position(length); + stash = out; + } + + static final class CorruptLZ4Frame extends IOException + { + CorruptLZ4Frame(String message) + { + super(message); + } + } + + private static void except(String format, Object... 
args) throws CorruptLZ4Frame + { + throw new CorruptLZ4Frame(format(format, args)); + } + } +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoderUnprotected.java b/src/java/org/apache/cassandra/net/FrameDecoderUnprotected.java new file mode 100644 index 000000000000..44414e3ba3e8 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderUnprotected.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.util.Collection; + +import io.netty.channel.ChannelPipeline; + +import static org.apache.cassandra.net.FrameDecoderCrc.HEADER_LENGTH; +import static org.apache.cassandra.net.FrameDecoderCrc.isSelfContained; +import static org.apache.cassandra.net.FrameDecoderCrc.payloadLength; +import static org.apache.cassandra.net.FrameDecoderCrc.readHeader6b; +import static org.apache.cassandra.net.FrameDecoderCrc.verifyHeader6b; + +/** + * A frame decoder for unprotected frames, i.e. those without any modification or payload protection. + * This is non-standard, and useful for systems that have a trusted transport layer that want + * to avoid incurring the (very low) cost of computing a CRC. 
All we do is accumulate the bytes + * of the frame, verify the frame header, and pass through the bytes stripped of the header. + * + * Every on-wire frame contains: + * 1. Payload length (17 bits) + * 2. {@code isSelfContained} flag (1 bit) + * 3. Header padding (6 bits) + * 4. CRC24 of the header (24 bits) + * 5. Payload (up to 2 ^ 17 - 1 bits) + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Payload Length |C| | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * CRC24 of Header | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + * | | + * + + + * | Payload | + * + + + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +final class FrameDecoderUnprotected extends FrameDecoderWith8bHeader +{ + FrameDecoderUnprotected(BufferPoolAllocator allocator) + { + super(allocator); + } + + public static FrameDecoderUnprotected create(BufferPoolAllocator allocator) + { + return new FrameDecoderUnprotected(allocator); + } + + final long readHeader(ByteBuffer frame, int begin) + { + return readHeader6b(frame, begin); + } + + final CorruptFrame verifyHeader(long header6b) + { + return verifyHeader6b(header6b); + } + + final int frameLength(long header6b) + { + return payloadLength(header6b) + HEADER_LENGTH; + } + + final Frame unpackFrame(ShareableBytes bytes, int begin, int end, long header6b) + { + boolean isSelfContained = isSelfContained(header6b); + return new IntactFrame(isSelfContained, bytes.slice(begin + HEADER_LENGTH, end)); + } + + void decode(Collection into, ShareableBytes bytes) + { + decode(into, bytes, HEADER_LENGTH); + } + + void addLastTo(ChannelPipeline pipeline) + { + pipeline.addLast("frameDecoderUnprotected", this); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameDecoderWith8bHeader.java b/src/java/org/apache/cassandra/net/FrameDecoderWith8bHeader.java new file mode 100644 index 
000000000000..ed87d8272f0a --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameDecoderWith8bHeader.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.util.Collection; + +import net.nicoulaj.compilecommand.annotations.Inline; + +/** + * An abstract frame decoder for frames utilising a fixed length header of 8 bytes or smaller. + * Implements a generic frame decode method, that is backed by the four abstract methods + * (three of which simply decode and verify the header as a long). + * + * Implementors are expected to declare their implementation methods final, and an outer decode + * method implemented to invoke this class' {@link #decode}, so that it may be inlined with the + * abstract method implementations then inlined into it. + */ +abstract class FrameDecoderWith8bHeader extends FrameDecoder +{ + FrameDecoderWith8bHeader(BufferPoolAllocator allocator) + { + super(allocator); + } + + /** + * Read a header that is 8 bytes or shorter, without modifying the buffer position. 
+ * If your header is longer than this, you will need to implement your own {@link #decode} + */ + abstract long readHeader(ByteBuffer in, int begin); + /** + * Verify the header, and return an unrecoverable CorruptFrame if it is corrupted + * @return null or CorruptFrame.unrecoverable + */ + abstract CorruptFrame verifyHeader(long header); + + /** + * Calculate the full frame length from info provided by the header, including the length of the header and any triler + */ + abstract int frameLength(long header); + + /** + * Extract a frame known to cover the given range. + * If {@code transferOwnership}, the method is responsible for ensuring bytes.release() is invoked at some future point. + */ + abstract Frame unpackFrame(ShareableBytes bytes, int begin, int end, long header); + + /** + * Decode a number of frames using the above abstract method implementations. + * It is expected for this method to be invoked by the implementing class' {@link #decode(Collection, ShareableBytes)} + * so that this implementation will be inlined, and all of the abstract method implementations will also be inlined. 
+ */ + @Inline + protected void decode(Collection into, ShareableBytes newBytes, int headerLength) + { + ByteBuffer in = newBytes.get(); + + try + { + if (stash != null) + { + if (!copyToSize(in, stash, headerLength)) + return; + + long header = readHeader(stash, 0); + CorruptFrame c = verifyHeader(header); + if (c != null) + { + discard(); + into.add(c); + return; + } + + int frameLength = frameLength(header); + stash = ensureCapacity(stash, frameLength); + + if (!copyToSize(in, stash, frameLength)) + return; + + stash.flip(); + ShareableBytes stashed = ShareableBytes.wrap(stash); + stash = null; + + try + { + into.add(unpackFrame(stashed, 0, frameLength, header)); + } + finally + { + stashed.release(); + } + } + + int begin = in.position(); + int limit = in.limit(); + while (begin < limit) + { + int remaining = limit - begin; + if (remaining < headerLength) + { + stash(newBytes, headerLength, begin, remaining); + return; + } + + long header = readHeader(in, begin); + CorruptFrame c = verifyHeader(header); + if (c != null) + { + into.add(c); + return; + } + + int frameLength = frameLength(header); + if (remaining < frameLength) + { + stash(newBytes, frameLength, begin, remaining); + return; + } + + into.add(unpackFrame(newBytes, begin, begin + frameLength, header)); + begin += frameLength; + } + } + finally + { + newBytes.release(); + } + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoder.java b/src/java/org/apache/cassandra/net/FrameEncoder.java new file mode 100644 index 000000000000..d9df1666b785 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoder.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.utils.memory.BufferPool; + +abstract class FrameEncoder extends ChannelOutboundHandlerAdapter +{ + /** + * An abstraction useful for transparently allocating buffers that can be written to upstream + * of the {@code FrameEncoder} without knowledge of the encoder's frame layout, while ensuring + * enough space to write the remainder of the frame's contents is reserved. 
+ */ + static class Payload + { + // isSelfContained is a flag in the Frame API, indicating if the contents consists of only complete messages + private boolean isSelfContained; + // the buffer to write to + final ByteBuffer buffer; + // the number of header bytes to reserve + final int headerLength; + // the number of trailer bytes to reserve + final int trailerLength; + // an API-misuse detector + private boolean isFinished = false; + + Payload(boolean isSelfContained, int payloadCapacity) + { + this(isSelfContained, payloadCapacity, 0, 0); + } + + Payload(boolean isSelfContained, int payloadCapacity, int headerLength, int trailerLength) + { + this.isSelfContained = isSelfContained; + this.headerLength = headerLength; + this.trailerLength = trailerLength; + + buffer = BufferPool.getAtLeast(payloadCapacity + headerLength + trailerLength, BufferType.OFF_HEAP); + assert buffer.capacity() >= payloadCapacity + headerLength + trailerLength; + buffer.position(headerLength); + buffer.limit(buffer.capacity() - trailerLength); + } + + void setSelfContained(boolean isSelfContained) + { + this.isSelfContained = isSelfContained; + } + + // do not invoke after finish() + boolean isEmpty() + { + assert !isFinished; + return buffer.position() == headerLength; + } + + // do not invoke after finish() + int length() + { + assert !isFinished; + return buffer.position() - headerLength; + } + + // do not invoke after finish() + int remaining() + { + assert !isFinished; + return buffer.remaining(); + } + + // do not invoke after finish() + void trim(int length) + { + assert !isFinished; + buffer.position(headerLength + length); + } + + // may not be written to or queried, after this is invoked; must be passed straight to an encoder (or release called) + void finish() + { + assert !isFinished; + isFinished = true; + buffer.limit(buffer.position() + trailerLength); + buffer.position(0); + BufferPool.putUnusedPortion(buffer); + } + + void release() + { + BufferPool.put(buffer); + } + } + 
+ interface PayloadAllocator + { + public static final PayloadAllocator simple = Payload::new; + Payload allocate(boolean isSelfContained, int capacity); + } + + PayloadAllocator allocator() + { + return PayloadAllocator.simple; + } + + /** + * Takes ownership of the lifetime of the provided buffer, which can be assumed to be managed by BufferPool + */ + abstract ByteBuf encode(boolean isSelfContained, ByteBuffer buffer); + + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception + { + if (!(msg instanceof Payload)) + throw new IllegalStateException("Unexpected type: " + msg); + + Payload payload = (Payload) msg; + ByteBuf write = encode(payload.isSelfContained, payload.buffer); + ctx.write(write, promise); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoderCrc.java b/src/java/org/apache/cassandra/net/FrameEncoderCrc.java new file mode 100644 index 000000000000..2d07d6d1cbc0 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoderCrc.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.zip.CRC32; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import org.apache.cassandra.utils.memory.BufferPool; + +import static org.apache.cassandra.net.Crc.*; + +/** + * Please see {@link FrameDecoderCrc} for description of the framing produced by this encoder. + */ +@ChannelHandler.Sharable +class FrameEncoderCrc extends FrameEncoder +{ + static final int HEADER_LENGTH = 6; + private static final int TRAILER_LENGTH = 4; + static final int HEADER_AND_TRAILER_LENGTH = 10; + + static final FrameEncoderCrc instance = new FrameEncoderCrc(); + static final PayloadAllocator allocator = (isSelfContained, capacity) -> + new Payload(isSelfContained, capacity, HEADER_LENGTH, TRAILER_LENGTH); + + PayloadAllocator allocator() + { + return allocator; + } + + static void writeHeader(ByteBuffer frame, boolean isSelfContained, int dataLength) + { + int header3b = dataLength; + if (isSelfContained) + header3b |= 1 << 17; + int crc = crc24(header3b, 3); + + put3b(frame, 0, header3b); + put3b(frame, 3, crc); + } + + private static void put3b(ByteBuffer frame, int index, int put3b) + { + frame.put(index , (byte) put3b ); + frame.put(index + 1, (byte)(put3b >>> 8) ); + frame.put(index + 2, (byte)(put3b >>> 16)); + } + + ByteBuf encode(boolean isSelfContained, ByteBuffer frame) + { + try + { + int frameLength = frame.remaining(); + int dataLength = frameLength - HEADER_AND_TRAILER_LENGTH; + if (dataLength >= 1 << 17) + throw new IllegalArgumentException("Maximum payload size is 128KiB"); + + writeHeader(frame, isSelfContained, dataLength); + + CRC32 crc = crc32(); + frame.position(HEADER_LENGTH); + frame.limit(dataLength + HEADER_LENGTH); + crc.update(frame); + + int frameCrc = (int) crc.getValue(); + if (frame.order() == ByteOrder.BIG_ENDIAN) + frameCrc = Integer.reverseBytes(frameCrc); + + frame.limit(frameLength); + 
frame.putInt(frameLength - TRAILER_LENGTH, frameCrc); + frame.position(0); + return GlobalBufferPoolAllocator.wrap(frame); + } + catch (Throwable t) + { + BufferPool.put(frame); + throw t; + } + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoderLZ4.java b/src/java/org/apache/cassandra/net/FrameEncoderLZ4.java new file mode 100644 index 000000000000..12351ce887ba --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoderLZ4.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.zip.CRC32; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Factory; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.memory.BufferPool; + +import static org.apache.cassandra.net.Crc.*; + +/** + * Please see {@link FrameDecoderLZ4} for description of the framing produced by this encoder. 
+ */ +@ChannelHandler.Sharable +class FrameEncoderLZ4 extends FrameEncoder +{ + static final FrameEncoderLZ4 fastInstance = new FrameEncoderLZ4(LZ4Factory.fastestInstance().fastCompressor()); + + private final LZ4Compressor compressor; + + private FrameEncoderLZ4(LZ4Compressor compressor) + { + this.compressor = compressor; + } + + private static final int HEADER_LENGTH = 8; + static final int HEADER_AND_TRAILER_LENGTH = 12; + + private static void writeHeader(ByteBuffer frame, boolean isSelfContained, long compressedLength, long uncompressedLength) + { + long header5b = compressedLength | (uncompressedLength << 17); + if (isSelfContained) + header5b |= 1L << 34; + + long crc = crc24(header5b, 5); + + long header8b = header5b | (crc << 40); + if (frame.order() == ByteOrder.BIG_ENDIAN) + header8b = Long.reverseBytes(header8b); + + frame.putLong(0, header8b); + } + + public ByteBuf encode(boolean isSelfContained, ByteBuffer in) + { + ByteBuffer frame = null; + try + { + int uncompressedLength = in.remaining(); + if (uncompressedLength >= 1 << 17) + throw new IllegalArgumentException("Maximum uncompressed payload size is 128KiB"); + + int maxOutputLength = compressor.maxCompressedLength(uncompressedLength); + frame = BufferPool.getAtLeast(HEADER_AND_TRAILER_LENGTH + maxOutputLength, BufferType.OFF_HEAP); + + int compressedLength = compressor.compress(in, in.position(), uncompressedLength, frame, HEADER_LENGTH, maxOutputLength); + + if (compressedLength >= uncompressedLength) + { + ByteBufferUtil.copyBytes(in, in.position(), frame, HEADER_LENGTH, uncompressedLength); + compressedLength = uncompressedLength; + uncompressedLength = 0; + } + + writeHeader(frame, isSelfContained, compressedLength, uncompressedLength); + + CRC32 crc = crc32(); + frame.position(HEADER_LENGTH); + frame.limit(compressedLength + HEADER_LENGTH); + crc.update(frame); + + int frameCrc = (int) crc.getValue(); + if (frame.order() == ByteOrder.BIG_ENDIAN) + frameCrc = Integer.reverseBytes(frameCrc); 
+ int frameLength = compressedLength + HEADER_AND_TRAILER_LENGTH; + + frame.limit(frameLength); + frame.putInt(frameCrc); + frame.position(0); + + BufferPool.putUnusedPortion(frame); + return GlobalBufferPoolAllocator.wrap(frame); + } + catch (Throwable t) + { + if (frame != null) + BufferPool.put(frame); + throw t; + } + finally + { + BufferPool.put(in); + } + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoderLegacy.java b/src/java/org/apache/cassandra/net/FrameEncoderLegacy.java new file mode 100644 index 000000000000..8bfd2678ad1b --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoderLegacy.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; + +/** + * A no-op frame encoder: legacy format doesn't support framing. Instead, the byte stream + * contains messages, serialized back to back. 
+ */ +@ChannelHandler.Sharable +class FrameEncoderLegacy extends FrameEncoder +{ + static final FrameEncoderLegacy instance = new FrameEncoderLegacy(); + + ByteBuf encode(boolean isSelfContained, ByteBuffer buffer) + { + return GlobalBufferPoolAllocator.wrap(buffer); + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoderLegacyLZ4.java b/src/java/org/apache/cassandra/net/FrameEncoderLegacyLZ4.java new file mode 100644 index 000000000000..3b29ecb7ae56 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoderLegacyLZ4.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.lang.Integer.reverseBytes; +import static java.lang.Math.min; +import static org.apache.cassandra.net.LegacyLZ4Constants.*; + +/** + * LZ4 {@link FrameEncoder} implementation for compressed legacy (3.0, 3.11) connections. + * + * Netty's provided implementation - {@link io.netty.handler.codec.compression.Lz4FrameEncoder} couldn't be reused + * for two reasons: + * 1. It notifies flushes as successful when they may not be, by flushing an empty buffer ahead + * of the compressed buffer + * 2. It has very poor performance when coupled with xxHash, which we use for legacy connections - + * allocating a single-byte array and making a JNI call for every byte of the payload + * + * Please see {@link FrameDecoderLegacyLZ4} for the description of the on-wire format of the LZ4 blocks + * used by this encoder. 
+ */ +@ChannelHandler.Sharable +class FrameEncoderLegacyLZ4 extends FrameEncoder +{ + static final FrameEncoderLegacyLZ4 instance = + new FrameEncoderLegacyLZ4(XXHashFactory.fastestInstance().hash32(), + LZ4Factory.fastestInstance().fastCompressor()); + + private final XXHash32 xxhash; + private final LZ4Compressor compressor; + + private FrameEncoderLegacyLZ4(XXHash32 xxhash, LZ4Compressor compressor) + { + this.xxhash = xxhash; + this.compressor = compressor; + } + + @Override + ByteBuf encode(boolean isSelfContained, ByteBuffer payload) + { + ByteBuffer frame = null; + try + { + frame = BufferPool.getAtLeast(calculateMaxFrameLength(payload), BufferType.OFF_HEAP); + + int frameOffset = 0; + int payloadOffset = 0; + + int payloadLength = payload.remaining(); + while (payloadOffset < payloadLength) + { + int blockLength = min(DEFAULT_BLOCK_LENGTH, payloadLength - payloadOffset); + frameOffset += compressBlock(frame, frameOffset, payload, payloadOffset, blockLength); + payloadOffset += blockLength; + } + + frame.limit(frameOffset); + BufferPool.putUnusedPortion(frame); + + return GlobalBufferPoolAllocator.wrap(frame); + } + catch (Throwable t) + { + if (null != frame) + BufferPool.put(frame); + throw t; + } + finally + { + BufferPool.put(payload); + } + } + + private int compressBlock(ByteBuffer frame, int frameOffset, ByteBuffer payload, int payloadOffset, int blockLength) + { + int frameBytesRemaining = frame.limit() - (frameOffset + HEADER_LENGTH); + int compressedLength = compressor.compress(payload, payloadOffset, blockLength, frame, frameOffset + HEADER_LENGTH, frameBytesRemaining); + if (compressedLength >= blockLength) + { + ByteBufferUtil.copyBytes(payload, payloadOffset, frame, frameOffset + HEADER_LENGTH, blockLength); + compressedLength = blockLength; + } + int checksum = xxhash.hash(payload, payloadOffset, blockLength, XXHASH_SEED) & XXHASH_MASK; + writeHeader(frame, frameOffset, compressedLength, blockLength, checksum); + return HEADER_LENGTH + 
compressedLength; + } + + private static final byte TOKEN_NON_COMPRESSED = 0x15; + private static final byte TOKEN_COMPRESSED = 0x25; + + private static void writeHeader(ByteBuffer frame, int frameOffset, int compressedLength, int uncompressedLength, int checksum) + { + byte token = compressedLength == uncompressedLength + ? TOKEN_NON_COMPRESSED + : TOKEN_COMPRESSED; + + frame.putLong(frameOffset + MAGIC_NUMBER_OFFSET, MAGIC_NUMBER ); + frame.put (frameOffset + TOKEN_OFFSET, token ); + frame.putInt (frameOffset + COMPRESSED_LENGTH_OFFSET, reverseBytes(compressedLength) ); + frame.putInt (frameOffset + UNCOMPRESSED_LENGTH_OFFSET, reverseBytes(uncompressedLength)); + frame.putInt (frameOffset + CHECKSUM_OFFSET, reverseBytes(checksum) ); + } + + private int calculateMaxFrameLength(ByteBuffer payload) + { + int payloadLength = payload.remaining(); + int blockCount = payloadLength / DEFAULT_BLOCK_LENGTH + (payloadLength % DEFAULT_BLOCK_LENGTH != 0 ? 1 : 0); + return compressor.maxCompressedLength(payloadLength) + HEADER_LENGTH * blockCount; + } +} diff --git a/src/java/org/apache/cassandra/net/FrameEncoderUnprotected.java b/src/java/org/apache/cassandra/net/FrameEncoderUnprotected.java new file mode 100644 index 000000000000..3bca41c25532 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FrameEncoderUnprotected.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import org.apache.cassandra.utils.memory.BufferPool; + +import static org.apache.cassandra.net.FrameEncoderCrc.HEADER_LENGTH; +import static org.apache.cassandra.net.FrameEncoderCrc.writeHeader; + +/** + * A frame encoder that writes frames, just without any modification or payload protection. + * This is non-standard, and useful for systems that have a trusted transport layer that want + * to avoid incurring the (very low) cost of computing a CRC. + * + * Please see {@link FrameDecoderUnprotected} for description of the framing produced by this encoder. 
+ */ +@ChannelHandler.Sharable +class FrameEncoderUnprotected extends FrameEncoder +{ + static final FrameEncoderUnprotected instance = new FrameEncoderUnprotected(); + static final PayloadAllocator allocator = (isSelfContained, capacity) -> + new Payload(isSelfContained, capacity, HEADER_LENGTH, 0); + + PayloadAllocator allocator() + { + return allocator; + } + + ByteBuf encode(boolean isSelfContained, ByteBuffer frame) + { + try + { + int frameLength = frame.remaining(); + int dataLength = frameLength - HEADER_LENGTH; + if (dataLength >= 1 << 17) + throw new IllegalArgumentException("Maximum uncompressed payload size is 128KiB"); + + writeHeader(frame, isSelfContained, dataLength); + return GlobalBufferPoolAllocator.wrap(frame); + } + catch (Throwable t) + { + BufferPool.put(frame); + throw t; + } + } +} diff --git a/src/java/org/apache/cassandra/net/FutureCombiner.java b/src/java/org/apache/cassandra/net/FutureCombiner.java new file mode 100644 index 000000000000..dd094bdcfdb6 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FutureCombiner.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.Collection; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.concurrent.Promise; + +/** + * Netty's PromiseCombiner is not threadsafe, and we combine futures from multiple event executors. + * + * This class groups a number of Future into a single logical Future, by registering a listener to each that + * decrements a shared counter; if any of them fail, the FutureCombiner is completed with the first cause, + * but in all scenario only completes when all underlying future have completed (exceptionally or otherwise) + * + * This Future is always uncancellable. + * + * We extend FutureDelegate, and simply provide it an uncancellable Promise that will be completed by the listeners + * registered to the input futures. 
+ */ +class FutureCombiner extends FutureDelegate +{ + private volatile boolean failed; + + private volatile Throwable firstCause; + private static final AtomicReferenceFieldUpdater firstCauseUpdater = + AtomicReferenceFieldUpdater.newUpdater(FutureCombiner.class, Throwable.class, "firstCause"); + + private volatile int waitingOn; + private static final AtomicIntegerFieldUpdater waitingOnUpdater = + AtomicIntegerFieldUpdater.newUpdater(FutureCombiner.class, "waitingOn"); + + FutureCombiner(Collection> combine) + { + this(AsyncPromise.uncancellable(GlobalEventExecutor.INSTANCE), combine); + } + + private FutureCombiner(Promise combined, Collection> combine) + { + super(combined); + + if (0 == (waitingOn = combine.size())) + combined.trySuccess(null); + + GenericFutureListener> listener = result -> + { + if (!result.isSuccess()) + { + firstCauseUpdater.compareAndSet(this, null, result.cause()); + failed = true; + } + + if (0 == waitingOnUpdater.decrementAndGet(this)) + { + if (failed) + combined.tryFailure(firstCause); + else + combined.trySuccess(null); + } + }; + + for (Future future : combine) + future.addListener(listener); + } +} diff --git a/src/java/org/apache/cassandra/net/FutureDelegate.java b/src/java/org/apache/cassandra/net/FutureDelegate.java new file mode 100644 index 000000000000..f04a43275fd0 --- /dev/null +++ b/src/java/org/apache/cassandra/net/FutureDelegate.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +/** + * A delegating future, that we can extend to provide subtly modified behaviour. + * + * See {@link FutureCombiner} and {@link FutureResult} + */ +class FutureDelegate implements Future +{ + final Future delegate; + + FutureDelegate(Future delegate) + { + this.delegate = delegate; + } + + public boolean isSuccess() + { + return delegate.isSuccess(); + } + + public boolean isCancellable() + { + return delegate.isCancellable(); + } + + public Throwable cause() + { + return delegate.cause(); + } + + public Future addListener(GenericFutureListener> genericFutureListener) + { + return delegate.addListener(genericFutureListener); + } + + public Future addListeners(GenericFutureListener>... genericFutureListeners) + { + return delegate.addListeners(genericFutureListeners); + } + + public Future removeListener(GenericFutureListener> genericFutureListener) + { + return delegate.removeListener(genericFutureListener); + } + + public Future removeListeners(GenericFutureListener>... 
genericFutureListeners) + { + return delegate.removeListeners(genericFutureListeners); + } + + public Future sync() throws InterruptedException + { + return delegate.sync(); + } + + public Future syncUninterruptibly() + { + return delegate.syncUninterruptibly(); + } + + public Future await() throws InterruptedException + { + return delegate.await(); + } + + public Future awaitUninterruptibly() + { + return delegate.awaitUninterruptibly(); + } + + public boolean await(long l, TimeUnit timeUnit) throws InterruptedException + { + return delegate.await(l, timeUnit); + } + + public boolean await(long l) throws InterruptedException + { + return delegate.await(l); + } + + public boolean awaitUninterruptibly(long l, TimeUnit timeUnit) + { + return delegate.awaitUninterruptibly(l, timeUnit); + } + + public boolean awaitUninterruptibly(long l) + { + return delegate.awaitUninterruptibly(l); + } + + public V getNow() + { + return delegate.getNow(); + } + + public boolean cancel(boolean b) + { + return delegate.cancel(b); + } + + public boolean isCancelled() + { + return delegate.isCancelled(); + } + + public boolean isDone() + { + return delegate.isDone(); + } + + public V get() throws InterruptedException, ExecutionException + { + return delegate.get(); + } + + public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException + { + return delegate.get(timeout, unit); + } +} diff --git a/src/java/org/apache/cassandra/net/async/ByteBufDataInputPlus.java b/src/java/org/apache/cassandra/net/FutureResult.java similarity index 54% rename from src/java/org/apache/cassandra/net/async/ByteBufDataInputPlus.java rename to src/java/org/apache/cassandra/net/FutureResult.java index e0be7151b1ef..8d43dbe39a78 100644 --- a/src/java/org/apache/cassandra/net/async/ByteBufDataInputPlus.java +++ b/src/java/org/apache/cassandra/net/FutureResult.java @@ -15,37 +15,33 @@ * See the License for the specific language governing permissions and * limitations 
under the License. */ +package org.apache.cassandra.net; -package org.apache.cassandra.net.async; +import io.netty.util.concurrent.Future; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufInputStream; -import org.apache.cassandra.io.util.DataInputPlus; - -import java.io.IOException; - -public class ByteBufDataInputPlus extends ByteBufInputStream implements DataInputPlus +/** + * An abstraction for yielding a result performed by an asynchronous task, + * for whom we may wish to offer cancellation, + * but no other access to the underlying task + */ +class FutureResult extends FutureDelegate { + private final Future tryCancel; + /** - * The parent class does not expose the buffer to derived classes, so we need - * to stash a reference here so it can be exposed via {@link #buffer()}. + * @param result the Future that will be completed by {@link #cancel} + * @param cancel the Future that is performing the work, and to whom any cancellation attempts will be proxied */ - private final ByteBuf buf; - - public ByteBufDataInputPlus(ByteBuf buffer) - { - super(buffer); - this.buf = buffer; - } - - public ByteBuf buffer() + FutureResult(Future result, Future cancel) { - return buf; + super(result); + this.tryCancel = cancel; } @Override - public String readUTF() throws IOException + public boolean cancel(boolean b) { - return DataInputStreamPlus.readUTF(this); + tryCancel.cancel(true); + return delegate.cancel(b); } } diff --git a/test/unit/org/apache/cassandra/net/async/TestAuthenticator.java b/src/java/org/apache/cassandra/net/GlobalBufferPoolAllocator.java similarity index 60% rename from test/unit/org/apache/cassandra/net/async/TestAuthenticator.java rename to src/java/org/apache/cassandra/net/GlobalBufferPoolAllocator.java index 3107f2aba0bc..66cbc9e1ebb0 100644 --- a/test/unit/org/apache/cassandra/net/async/TestAuthenticator.java +++ b/src/java/org/apache/cassandra/net/GlobalBufferPoolAllocator.java @@ -15,28 +15,27 @@ * See the License for the specific 
language governing permissions and * limitations under the License. */ +package org.apache.cassandra.net; -package org.apache.cassandra.net.async; +import java.nio.ByteBuffer; -import java.net.InetAddress; +import io.netty.buffer.ByteBuf; +import org.apache.cassandra.utils.memory.BufferPool; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.exceptions.ConfigurationException; - -class TestAuthenticator implements IInternodeAuthenticator +/** + * Primary {@link ByteBuf} / {@link ByteBuffer} allocator - using the global {@link BufferPool}. + */ +class GlobalBufferPoolAllocator extends BufferPoolAllocator { - private final boolean authAll; + static final GlobalBufferPoolAllocator instance = new GlobalBufferPoolAllocator(); - TestAuthenticator(boolean authAll) + private GlobalBufferPoolAllocator() { - this.authAll = authAll; + super(); } - public boolean authenticate(InetAddress remoteAddress, int remotePort) + static ByteBuf wrap(ByteBuffer buffer) { - return authAll; + return new Wrapped(instance, buffer); } - - public void validateConfiguration() throws ConfigurationException - { } } diff --git a/src/java/org/apache/cassandra/net/HandshakeProtocol.java b/src/java/org/apache/cassandra/net/HandshakeProtocol.java new file mode 100644 index 000000000000..47d0ec6dffa3 --- /dev/null +++ b/src/java/org/apache/cassandra/net/HandshakeProtocol.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +import com.google.common.annotations.VisibleForTesting; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputBufferFixed; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.Message.validateLegacyProtocolMagic; +import static org.apache.cassandra.net.Crc.*; +import static org.apache.cassandra.net.Crc.computeCrc32; +import static org.apache.cassandra.net.OutboundConnectionSettings.*; + +/** + * Messages for the handshake phase of the internode protocol. + * + * The modern handshake is composed of 2 messages: Initiate and Accept + *

+ * The legacy handshake is composed of 3 messages, the first being sent by the initiator of the connection. The other + * side then answer with the 2nd message. At that point, if a version mismatch is detected by the connection initiator, + * it will simply disconnect and reconnect with a more appropriate version. But if the version is acceptable, the connection + * initiator sends the third message of the protocol, after which it considers the connection ready. + */ +class HandshakeProtocol +{ + static final long TIMEOUT_MILLIS = 3 * DatabaseDescriptor.getRpcTimeout(MILLISECONDS); + + /** + * The initial message sent when a node creates a new connection to a remote peer. This message contains: + * 1) the {@link Message#PROTOCOL_MAGIC} number (4 bytes). + * 2) the connection flags (4 bytes), which encodes: + * - the version the initiator thinks should be used for the connection (in practice, either the initiator + * version if it's the first time we connect to that remote since startup, or the last version known for that + * peer otherwise). + * - the "mode" of the connection: whether it is for streaming or for messaging. + * - whether compression should be used or not (if it is, compression is enabled _after_ the last message of the + * handshake has been sent). + * 3) the connection initiator's broadcast address + * 4) a CRC protecting the message from corruption + *

+ * More precisely, connection flags: + *

+     * {@code
+     *                      1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3
+     *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * |C C C M C      |    REQUEST    |      MIN      |      MAX      |
+     * |A A M O R      |    VERSION    |   SUPPORTED   |   SUPPORTED   |
+     * |T T P D C      |  (DEPRECATED) |    VERSION    |    VERSION    |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * }
+     * 
+ * CAT - QOS category, 2 bits: SMALL, LARGE, URGENT, or LEGACY (unset) + * CMP - compression enabled bit + * MOD - connection mode; if the bit is on, the connection is for streaming; if the bit is off, it is for inter-node messaging. + * CRC - crc enabled bit + * VERSION - {@link org.apache.cassandra.net.MessagingService#current_version} + */ + static class Initiate + { + /** Contains the PROTOCOL_MAGIC (int) and the flags (int). */ + private static final int MIN_LENGTH = 8; + private static final int MAX_LENGTH = 12 + InetAddressAndPort.Serializer.MAXIMUM_SIZE; + + @Deprecated // this is ignored by post40 nodes, i.e. if maxMessagingVersion is set + final int requestMessagingVersion; + // the messagingVersion bounds the sender will accept to initiate a connection; + // if the remote peer supports any, the newest supported version will be selected; otherwise the nearest supported version + final AcceptVersions acceptVersions; + final ConnectionType type; + final Framing framing; + final InetAddressAndPort from; + + Initiate(int requestMessagingVersion, AcceptVersions acceptVersions, ConnectionType type, Framing framing, InetAddressAndPort from) + { + this.requestMessagingVersion = requestMessagingVersion; + this.acceptVersions = acceptVersions; + this.type = type; + this.framing = framing; + this.from = from; + } + + @VisibleForTesting + int encodeFlags() + { + int flags = 0; + if (type.isMessaging()) + flags |= type.twoBitID(); + if (type.isStreaming()) + flags |= 1 << 3; + + // framing id is split over 2nd and 4th bits, for backwards compatibility + flags |= ((framing.id & 1) << 2) | ((framing.id & 2) << 3); + flags |= (requestMessagingVersion << 8); + + if (requestMessagingVersion < VERSION_40 || acceptVersions.max < VERSION_40) + return flags; // for testing, permit serializing as though we are pre40 + + flags |= (acceptVersions.min << 16); + flags |= (acceptVersions.max << 24); + return flags; + } + + ByteBuf encode() + { + ByteBuffer buffer = 
BufferPool.get(MAX_LENGTH, BufferType.OFF_HEAP); + try (DataOutputBufferFixed out = new DataOutputBufferFixed(buffer)) + { + out.writeInt(Message.PROTOCOL_MAGIC); + out.writeInt(encodeFlags()); + + if (requestMessagingVersion >= VERSION_40 && acceptVersions.max >= VERSION_40) + { + inetAddressAndPortSerializer.serialize(from, out, requestMessagingVersion); + out.writeInt(computeCrc32(buffer, 0, buffer.position())); + } + buffer.flip(); + return GlobalBufferPoolAllocator.wrap(buffer); + } + catch (IOException e) + { + throw new IllegalStateException(e); + } + } + + static Initiate maybeDecode(ByteBuf buf) throws IOException + { + if (buf.readableBytes() < MIN_LENGTH) + return null; + + ByteBuffer nio = buf.nioBuffer(); + int start = nio.position(); + try (DataInputBuffer in = new DataInputBuffer(nio, false)) + { + validateLegacyProtocolMagic(in.readInt()); + int flags = in.readInt(); + + int requestedMessagingVersion = getBits(flags, 8, 8); + int minMessagingVersion = getBits(flags, 16, 8); + int maxMessagingVersion = getBits(flags, 24, 8); + int framingBits = getBits(flags, 2, 1) | (getBits(flags, 4, 1) << 1); + Framing framing = Framing.forId(framingBits); + + boolean isStream = getBits(flags, 3, 1) == 1; + + ConnectionType type = isStream + ? ConnectionType.STREAMING + : ConnectionType.fromId(getBits(flags, 0, 2)); + + InetAddressAndPort from = null; + + if (requestedMessagingVersion >= VERSION_40 && maxMessagingVersion >= MessagingService.VERSION_40) + { + from = inetAddressAndPortSerializer.deserialize(in, requestedMessagingVersion); + + int computed = computeCrc32(nio, start, nio.position()); + int read = in.readInt(); + if (read != computed) + throw new InvalidCrc(read, computed); + } + + buf.skipBytes(nio.position() - start); + return new Initiate(requestedMessagingVersion, + minMessagingVersion == 0 && maxMessagingVersion == 0 + ? 
null : new AcceptVersions(minMessagingVersion, maxMessagingVersion), + type, framing, from); + + } + catch (EOFException e) + { + return null; + } + } + + @VisibleForTesting + @Override + public boolean equals(Object other) + { + if (!(other instanceof Initiate)) + return false; + + Initiate that = (Initiate)other; + return this.type == that.type + && this.framing == that.framing + && this.requestMessagingVersion == that.requestMessagingVersion + && Objects.equals(this.acceptVersions, that.acceptVersions); + } + + @Override + public String toString() + { + return String.format("Initiate(request: %d, min: %d, max: %d, type: %s, framing: %b, from: %s)", + requestMessagingVersion, + acceptVersions == null ? requestMessagingVersion : acceptVersions.min, + acceptVersions == null ? requestMessagingVersion : acceptVersions.max, + type, framing, from); + } + } + + + /** + * The second message of the handshake, sent by the node receiving the {@link Initiate} back to the + * connection initiator. + * + * This message contains + * 1) the messaging version of the peer sending this message + * 2) the negotiated messaging version if one could be accepted by both peers, + * or if not the closest version that this peer could support to the ones requested + * 3) a CRC protecting the integrity of the message + * + * Note that the pre40 equivalent of this message contains ONLY the messaging version of the peer. + */ + static class Accept + { + /** The messaging version sent by the receiving peer (int). 
*/ + private static final int MAX_LENGTH = 12; + + final int useMessagingVersion; + final int maxMessagingVersion; + + Accept(int useMessagingVersion, int maxMessagingVersion) + { + this.useMessagingVersion = useMessagingVersion; + this.maxMessagingVersion = maxMessagingVersion; + } + + ByteBuf encode(ByteBufAllocator allocator) + { + ByteBuf buffer = allocator.directBuffer(MAX_LENGTH); + buffer.clear(); + buffer.writeInt(maxMessagingVersion); + buffer.writeInt(useMessagingVersion); + buffer.writeInt(computeCrc32(buffer, 0, 8)); + return buffer; + } + + /** + * Respond to pre40 nodes only with our current messagingVersion + */ + static ByteBuf respondPre40(int messagingVersion, ByteBufAllocator allocator) + { + ByteBuf buffer = allocator.directBuffer(4); + buffer.clear(); + buffer.writeInt(messagingVersion); + return buffer; + } + + static Accept maybeDecode(ByteBuf in, int handshakeMessagingVersion) throws InvalidCrc + { + int readerIndex = in.readerIndex(); + if (in.readableBytes() < 4) + return null; + int maxMessagingVersion = in.readInt(); + int useMessagingVersion = 0; + + // if the other node is pre-4.0, it will respond only with its maxMessagingVersion + if (maxMessagingVersion < VERSION_40 || handshakeMessagingVersion < VERSION_40) + return new Accept(useMessagingVersion, maxMessagingVersion); + + if (in.readableBytes() < 8) + { + in.readerIndex(readerIndex); + return null; + } + useMessagingVersion = in.readInt(); + + // verify crc + int computed = computeCrc32(in, readerIndex, readerIndex + 8); + int read = in.readInt(); + if (read != computed) + throw new InvalidCrc(read, computed); + + return new Accept(useMessagingVersion, maxMessagingVersion); + } + + @VisibleForTesting + @Override + public boolean equals(Object other) + { + return other instanceof Accept + && this.useMessagingVersion == ((Accept) other).useMessagingVersion + && this.maxMessagingVersion == ((Accept) other).maxMessagingVersion; + } + + @Override + public String toString() + { + return 
String.format("Accept(use: %d, max: %d)", useMessagingVersion, maxMessagingVersion); + } + } + + /** + * The third message of the handshake, sent by pre40 nodes on reception of {@link Accept}. + * This message contains: + * 1) The connection initiator's {@link org.apache.cassandra.net.MessagingService#current_version} (4 bytes). + * This indicates the max messaging version supported by this node. + * 2) The connection initiator's broadcast address as encoded by {@link InetAddressAndPort.Serializer}. + * This can be either 7 bytes for an IPv4 address, or 19 bytes for an IPv6 one, post40. + * This can be either 5 bytes for an IPv4 address, or 17 bytes for an IPv6 one, pre40. + *

+ * This message concludes the legacy handshake protocol. + */ + static class ConfirmOutboundPre40 + { + private static final int MAX_LENGTH = 4 + InetAddressAndPort.Serializer.MAXIMUM_SIZE; + + final int maxMessagingVersion; + final InetAddressAndPort from; + + ConfirmOutboundPre40(int maxMessagingVersion, InetAddressAndPort from) + { + this.maxMessagingVersion = maxMessagingVersion; + this.from = from; + } + + ByteBuf encode() + { + ByteBuffer buffer = BufferPool.get(MAX_LENGTH, BufferType.OFF_HEAP); + try (DataOutputBufferFixed out = new DataOutputBufferFixed(buffer)) + { + out.writeInt(maxMessagingVersion); + // pre-4.0 nodes should only receive the address, never port, and it's ok to hardcode VERSION_30 + inetAddressAndPortSerializer.serialize(from, out, VERSION_30); + buffer.flip(); + return GlobalBufferPoolAllocator.wrap(buffer); + } + catch (IOException e) + { + throw new IllegalStateException(e); + } + } + + @SuppressWarnings("resource") + static ConfirmOutboundPre40 maybeDecode(ByteBuf in) + { + ByteBuffer nio = in.nioBuffer(); + int start = nio.position(); + DataInputPlus input = new DataInputBuffer(nio, false); + try + { + int version = input.readInt(); + InetAddressAndPort address = inetAddressAndPortSerializer.deserialize(input, version); + in.skipBytes(nio.position() - start); + return new ConfirmOutboundPre40(version, address); + } + catch (EOFException e) + { + // makes the assumption we didn't have enough bytes to deserialize an IPv6 address, + // as we only check the MIN_LENGTH of the buf. 
+ return null; + } + catch (IOException e) + { + throw new IllegalStateException(e); + } + } + + @VisibleForTesting + @Override + public boolean equals(Object other) + { + if (!(other instanceof ConfirmOutboundPre40)) + return false; + + ConfirmOutboundPre40 that = (ConfirmOutboundPre40) other; + return this.maxMessagingVersion == that.maxMessagingVersion + && Objects.equals(this.from, that.from); + } + + @Override + public String toString() + { + return String.format("ConfirmOutboundPre40(maxMessagingVersion: %d; address: %s)", maxMessagingVersion, from); + } + } + + private static int getBits(int packed, int start, int count) + { + return (packed >>> start) & ~(-1 << count); + } + +} diff --git a/src/java/org/apache/cassandra/net/IVerbHandler.java b/src/java/org/apache/cassandra/net/IVerbHandler.java index 0995a68e9aed..ac0efe7359b0 100644 --- a/src/java/org/apache/cassandra/net/IVerbHandler.java +++ b/src/java/org/apache/cassandra/net/IVerbHandler.java @@ -24,7 +24,6 @@ * The concrete implementation of this interface would provide the functionality * for a given verb. */ - public interface IVerbHandler { /** @@ -34,7 +33,6 @@ public interface IVerbHandler * because the implementation may be synchronized. * * @param message - incoming message that needs handling. - * @param id */ - void doVerb(MessageIn message, int id) throws IOException; + void doVerb(Message message) throws IOException; } diff --git a/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java new file mode 100644 index 000000000000..c390ba4287e2 --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.Future; +import java.util.function.Consumer; + +import javax.net.ssl.SSLSession; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.OutboundConnectionSettings.Framing; +import org.apache.cassandra.security.SSLFactory; +import 
org.apache.cassandra.streaming.async.StreamingInboundHandler; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.lang.Math.*; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.cassandra.net.MessagingService.*; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.MessagingService.minimum_version; +import static org.apache.cassandra.net.SocketFactory.WIRETRACE; +import static org.apache.cassandra.net.SocketFactory.encryptionLogStatement; +import static org.apache.cassandra.net.SocketFactory.newSslHandler; + +public class InboundConnectionInitiator +{ + private static final Logger logger = LoggerFactory.getLogger(InboundConnectionInitiator.class); + + private static class Initializer extends ChannelInitializer + { + private final InboundConnectionSettings settings; + private final ChannelGroup channelGroup; + private final Consumer pipelineInjector; + + Initializer(InboundConnectionSettings settings, ChannelGroup channelGroup, + Consumer pipelineInjector) + { + this.settings = settings; + this.channelGroup = channelGroup; + this.pipelineInjector = pipelineInjector; + } + + @Override + public void initChannel(SocketChannel channel) throws Exception + { + channelGroup.add(channel); + + channel.config().setOption(ChannelOption.ALLOCATOR, GlobalBufferPoolAllocator.instance); + channel.config().setOption(ChannelOption.SO_KEEPALIVE, true); + channel.config().setOption(ChannelOption.SO_REUSEADDR, true); + channel.config().setOption(ChannelOption.TCP_NODELAY, true); // we only send handshake messages; no point ever delaying + + ChannelPipeline pipeline = channel.pipeline(); + + pipelineInjector.accept(pipeline); + + // order of handlers: ssl -> logger -> handshakeHandler + if (settings.encryption.enabled) + { + if (settings.encryption.optional) + { + pipeline.addFirst("ssl", new 
OptionalSslHandler(settings.encryption)); + } + else + { + SslContext sslContext = SSLFactory.getOrCreateSslContext(settings.encryption, true, SSLFactory.SocketType.SERVER); + InetSocketAddress peer = settings.encryption.require_endpoint_verification ? channel.remoteAddress() : null; + SslHandler sslHandler = newSslHandler(channel, sslContext, peer); + logger.trace("creating inbound netty SslContext: context={}, engine={}", sslContext.getClass().getName(), sslHandler.engine().getClass().getName()); + pipeline.addFirst("ssl", sslHandler); + } + } + + if (WIRETRACE) + pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO)); + + channel.pipeline().addLast("handshake", new Handler(settings)); + + } + } + + /** + * Create a {@link Channel} that listens on the {@code localAddr}. This method will block while trying to bind to the address, + * but it does not make a remote call. + */ + private static ChannelFuture bind(Initializer initializer) throws ConfigurationException + { + logger.info("Listening on {}", initializer.settings); + + ServerBootstrap bootstrap = initializer.settings.socketFactory + .newServerBootstrap() + .option(ChannelOption.SO_BACKLOG, 1 << 9) + .option(ChannelOption.ALLOCATOR, GlobalBufferPoolAllocator.instance) + .option(ChannelOption.SO_REUSEADDR, true) + .childHandler(initializer); + + int socketReceiveBufferSizeInBytes = initializer.settings.socketReceiveBufferSizeInBytes; + if (socketReceiveBufferSizeInBytes > 0) + bootstrap.childOption(ChannelOption.SO_RCVBUF, socketReceiveBufferSizeInBytes); + + InetAddressAndPort bind = initializer.settings.bindAddress; + ChannelFuture channelFuture = bootstrap.bind(new InetSocketAddress(bind.address, bind.port)); + + if (!channelFuture.awaitUninterruptibly().isSuccess()) + { + if (channelFuture.channel().isOpen()) + channelFuture.channel().close(); + + Throwable failedChannelCause = channelFuture.cause(); + + String causeString = ""; + if (failedChannelCause != null && failedChannelCause.getMessage() != 
null) + causeString = failedChannelCause.getMessage(); + + if (causeString.contains("in use")) + { + throw new ConfigurationException(bind + " is in use by another process. Change listen_address:storage_port " + + "in cassandra.yaml to values that do not conflict with other services"); + } + // looking at the jdk source, solaris/windows bind failure messages both use the phrase "cannot assign requested address". + // windows message uses "Cannot" (with a capital 'C'), and solaris (a/k/a *nux) does not. hence we search for "annot" + else if (causeString.contains("annot assign requested address")) + { + throw new ConfigurationException("Unable to bind to address " + bind + + ". Set listen_address in cassandra.yaml to an interface you can bind to, e.g., your private IP address on EC2"); + } + else + { + throw new ConfigurationException("failed to bind to: " + bind, failedChannelCause); + } + } + + return channelFuture; + } + + public static ChannelFuture bind(InboundConnectionSettings settings, ChannelGroup channelGroup, + Consumer pipelineInjector) + { + return bind(new Initializer(settings, channelGroup, pipelineInjector)); + } + + /** + * 'Server-side' component that negotiates the internode handshake when establishing a new connection. + * This handler will be the first in the netty channel for each incoming connection (secure socket (TLS) notwithstanding), + * and once the handshake is successful, it will configure the proper handlers ({@link InboundMessageHandler} + * or {@link StreamingInboundHandler}) and remove itself from the working pipeline. + */ + static class Handler extends ByteToMessageDecoder + { + private final InboundConnectionSettings settings; + + private HandshakeProtocol.Initiate initiate; + private HandshakeProtocol.ConfirmOutboundPre40 confirmOutboundPre40; + + /** + * A future that essentially places a timeout on how long we'll wait for the peer + * to complete the next step of the handshake. 
+ */ + private Future handshakeTimeout; + + Handler(InboundConnectionSettings settings) + { + this.settings = settings; + } + + /** + * On registration, immediately schedule a timeout to kill this connection if it does not handshake promptly, + * and authenticate the remote address. + */ + public void handlerAdded(ChannelHandlerContext ctx) throws Exception + { + handshakeTimeout = ctx.executor().schedule(() -> { + logger.error("Timeout handshaking with {} (on {})", SocketFactory.addressId(initiate.from, (InetSocketAddress) ctx.channel().remoteAddress()), settings.bindAddress); + failHandshake(ctx); + }, HandshakeProtocol.TIMEOUT_MILLIS, MILLISECONDS); + + logSsl(ctx); + authenticate(ctx.channel().remoteAddress()); + } + + private void authenticate(SocketAddress socketAddress) throws IOException + { + if (socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress")) + return; + + if (!(socketAddress instanceof InetSocketAddress)) + throw new IOException(String.format("Unexpected SocketAddress type: %s, %s", socketAddress.getClass(), socketAddress)); + + InetSocketAddress addr = (InetSocketAddress)socketAddress; + if (!settings.authenticate(addr.getAddress(), addr.getPort())) + throw new IOException("Authentication failure for inbound connection from peer " + addr); + } + + private void logSsl(ChannelHandlerContext ctx) + { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler != null) + { + SSLSession session = sslHandler.engine().getSession(); + logger.info("connection from peer {}, protocol = {}, cipher suite = {}", + ctx.channel().remoteAddress(), session.getProtocol(), session.getCipherSuite()); + } + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception + { + if (initiate == null) initiate(ctx, in); + else if (initiate.acceptVersions == null && confirmOutboundPre40 == null) confirmPre40(ctx, in); + else throw new IllegalStateException("Should no longer be on pipeline"); 
+ } + + void initiate(ChannelHandlerContext ctx, ByteBuf in) throws IOException + { + initiate = HandshakeProtocol.Initiate.maybeDecode(in); + if (initiate == null) + return; + + logger.trace("Received handshake initiation message from peer {}, message = {}", ctx.channel().remoteAddress(), initiate); + if (initiate.acceptVersions != null) + { + logger.trace("Connection version {} (min {}) from {}", initiate.acceptVersions.max, initiate.acceptVersions.min, initiate.from); + + final AcceptVersions accept; + + if (initiate.type.isStreaming()) + accept = settings.acceptStreaming; + else + accept = settings.acceptMessaging; + + int useMessagingVersion = max(accept.min, min(accept.max, initiate.acceptVersions.max)); + ByteBuf flush = new HandshakeProtocol.Accept(useMessagingVersion, accept.max).encode(ctx.alloc()); + + AsyncChannelPromise.writeAndFlush(ctx, flush, (ChannelFutureListener) future -> { + if (!future.isSuccess()) + exceptionCaught(future.channel(), future.cause()); + }); + + if (initiate.acceptVersions.min > accept.max) + { + logger.info("peer {} only supports messaging versions higher ({}) than this node supports ({})", ctx.channel().remoteAddress(), initiate.acceptVersions.min, current_version); + failHandshake(ctx); + } + else if (initiate.acceptVersions.max < accept.min) + { + logger.info("peer {} only supports messaging versions lower ({}) than this node supports ({})", ctx.channel().remoteAddress(), initiate.acceptVersions.max, minimum_version); + failHandshake(ctx); + } + else + { + if (initiate.type.isStreaming()) + setupStreamingPipeline(initiate.from, ctx); + else + setupMessagingPipeline(initiate.from, useMessagingVersion, initiate.acceptVersions.max, ctx.pipeline()); + } + } + else + { + int version = initiate.requestMessagingVersion; + assert version < VERSION_40 && version >= settings.acceptMessaging.min; + logger.trace("Connection version {} from {}", version, ctx.channel().remoteAddress()); + + if (initiate.type.isStreaming()) + { + // 
streaming connections are per-session and have a fixed version. we can't do anything with a wrong-version stream connection, so drop it. + if (version != settings.acceptStreaming.max) + { + logger.warn("Received stream using protocol version {} (my version {}). Terminating connection", version, settings.acceptStreaming.max); + failHandshake(ctx); + } + setupStreamingPipeline(initiate.from, ctx); + } + else + { + // if this version is < the MS version the other node is trying + // to connect with, the other node will disconnect + ByteBuf response = HandshakeProtocol.Accept.respondPre40(settings.acceptMessaging.max, ctx.alloc()); + AsyncChannelPromise.writeAndFlush(ctx, response, + (ChannelFutureListener) future -> { + if (!future.isSuccess()) + exceptionCaught(future.channel(), future.cause()); + }); + + if (version < VERSION_30) + throw new IOException(String.format("Unable to read obsolete message version %s from %s; The earliest version supported is 3.0.0", version, ctx.channel().remoteAddress())); + + // we don't setup the messaging pipeline here, as the legacy messaging handshake requires one more message to finish + } + } + } + + /** + * Handles the third (and last) message in the internode messaging handshake protocol for pre40 nodes. + * Grabs the protocol version and IP addr the peer wants to use. 
+ */ + @VisibleForTesting + void confirmPre40(ChannelHandlerContext ctx, ByteBuf in) + { + confirmOutboundPre40 = HandshakeProtocol.ConfirmOutboundPre40.maybeDecode(in); + if (confirmOutboundPre40 == null) + return; + + logger.trace("Received third handshake message from peer {}, message = {}", ctx.channel().remoteAddress(), confirmOutboundPre40); + setupMessagingPipeline(confirmOutboundPre40.from, initiate.requestMessagingVersion, confirmOutboundPre40.maxMessagingVersion, ctx.pipeline()); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + exceptionCaught(ctx.channel(), cause); + } + + private void exceptionCaught(Channel channel, Throwable cause) + { + logger.error("Failed to properly handshake with peer {}. Closing the channel.", channel.remoteAddress(), cause); + try + { + failHandshake(channel); + } + catch (Throwable t) + { + logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + } + } + + private void failHandshake(ChannelHandlerContext ctx) + { + failHandshake(ctx.channel()); + } + + private void failHandshake(Channel channel) + { + channel.close(); + if (handshakeTimeout != null) + handshakeTimeout.cancel(true); + } + + private void setupStreamingPipeline(InetAddressAndPort from, ChannelHandlerContext ctx) + { + handshakeTimeout.cancel(true); + assert initiate.framing == Framing.UNPROTECTED; + + ChannelPipeline pipeline = ctx.pipeline(); + Channel channel = ctx.channel(); + + if (from == null) + { + InetSocketAddress address = (InetSocketAddress) channel.remoteAddress(); + from = InetAddressAndPort.getByAddressOverrideDefaults(address.getAddress(), address.getPort()); + } + + BufferPool.setRecycleWhenFreeForCurrentThread(false); + pipeline.replace(this, "streamInbound", new StreamingInboundHandler(from, current_version, null)); + } + + @VisibleForTesting + void setupMessagingPipeline(InetAddressAndPort from, int useMessagingVersion, int maxMessagingVersion, 
ChannelPipeline pipeline) + { + handshakeTimeout.cancel(true); + // record the "true" endpoint, i.e. the one the peer is identified with, as opposed to the socket it connected over + instance().versions.set(from, maxMessagingVersion); + + BufferPool.setRecycleWhenFreeForCurrentThread(false); + BufferPoolAllocator allocator = GlobalBufferPoolAllocator.instance; + if (initiate.type == ConnectionType.LARGE_MESSAGES) + { + // for large messages, swap the global pool allocator for a local one, to optimise utilisation of chunks + allocator = new LocalBufferPoolAllocator(pipeline.channel().eventLoop()); + pipeline.channel().config().setAllocator(allocator); + } + + FrameDecoder frameDecoder; + switch (initiate.framing) + { + case LZ4: + { + if (useMessagingVersion >= VERSION_40) + frameDecoder = FrameDecoderLZ4.fast(allocator); + else + frameDecoder = new FrameDecoderLegacyLZ4(allocator, useMessagingVersion); + break; + } + case CRC: + { + if (useMessagingVersion >= VERSION_40) + { + frameDecoder = FrameDecoderCrc.create(allocator); + break; + } + } + case UNPROTECTED: + { + if (useMessagingVersion >= VERSION_40) + frameDecoder = new FrameDecoderUnprotected(allocator); + else + frameDecoder = new FrameDecoderLegacy(allocator, useMessagingVersion); + break; + } + default: + throw new AssertionError(); + } + + frameDecoder.addLastTo(pipeline); + + InboundMessageHandler handler = + settings.handlers.apply(from).createHandler(frameDecoder, initiate.type, pipeline.channel(), useMessagingVersion); + + logger.info("{} connection established, version = {}, framing = {}, encryption = {}", + handler.id(true), + useMessagingVersion, + initiate.framing, + pipeline.get("ssl") != null ? 
encryptionLogStatement(settings.encryption) : "disabled"); + + pipeline.addLast("deserialize", handler); + + pipeline.remove(this); + } + } + + private static class OptionalSslHandler extends ByteToMessageDecoder + { + private final EncryptionOptions.ServerEncryptionOptions encryptionOptions; + + OptionalSslHandler(EncryptionOptions.ServerEncryptionOptions encryptionOptions) + { + this.encryptionOptions = encryptionOptions; + } + + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception + { + if (in.readableBytes() < 5) + { + // To detect if SSL must be used we need to have at least 5 bytes, so return here and try again + // once more bytes are ready. + return; + } + + if (SslHandler.isEncrypted(in)) + { + // Connection uses SSL/TLS, replace the detection handler with a SslHandler and so use encryption. + SslContext sslContext = SSLFactory.getOrCreateSslContext(encryptionOptions, true, SSLFactory.SocketType.SERVER); + Channel channel = ctx.channel(); + InetSocketAddress peer = encryptionOptions.require_endpoint_verification ? (InetSocketAddress) channel.remoteAddress() : null; + SslHandler sslHandler = newSslHandler(channel, sslContext, peer); + ctx.pipeline().replace(this, "ssl", sslHandler); + } + else + { + // Connection uses no TLS/SSL encryption, just remove the detection handler and continue without + // SslHandler in the pipeline. + ctx.pipeline().remove(this); + } + } + } +} diff --git a/src/java/org/apache/cassandra/net/InboundConnectionSettings.java b/src/java/org/apache/cassandra/net/InboundConnectionSettings.java new file mode 100644 index 000000000000..a07395b8f270 --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundConnectionSettings.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.net.InetAddress; +import java.util.function.Function; + +import com.google.common.base.Preconditions; + +import org.apache.cassandra.auth.IInternodeAuthenticator; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.utils.FBUtilities; + +import static java.lang.String.format; +import static org.apache.cassandra.net.MessagingService.*; + +public class InboundConnectionSettings +{ + public final IInternodeAuthenticator authenticator; + public final InetAddressAndPort bindAddress; + public final ServerEncryptionOptions encryption; + public final Integer socketReceiveBufferSizeInBytes; + public final Integer applicationReceiveQueueCapacityInBytes; + public final AcceptVersions acceptMessaging; + public final AcceptVersions acceptStreaming; + public final SocketFactory socketFactory; + public final Function handlers; + + private InboundConnectionSettings(IInternodeAuthenticator authenticator, + InetAddressAndPort bindAddress, + ServerEncryptionOptions encryption, + Integer socketReceiveBufferSizeInBytes, + Integer applicationReceiveQueueCapacityInBytes, + AcceptVersions acceptMessaging, + AcceptVersions 
acceptStreaming, + SocketFactory socketFactory, + Function handlers) + { + this.authenticator = authenticator; + this.bindAddress = bindAddress; + this.encryption = encryption; + this.socketReceiveBufferSizeInBytes = socketReceiveBufferSizeInBytes; + this.applicationReceiveQueueCapacityInBytes = applicationReceiveQueueCapacityInBytes; + this.acceptMessaging = acceptMessaging; + this.acceptStreaming = acceptStreaming; + this.socketFactory = socketFactory; + this.handlers = handlers; + } + + public InboundConnectionSettings() + { + this(null, null, null, null, null, null, null, null, null); + } + + public boolean authenticate(InetAddressAndPort endpoint) + { + return authenticator.authenticate(endpoint.address, endpoint.port); + } + + public boolean authenticate(InetAddress address, int port) + { + return authenticator.authenticate(address, port); + } + + public String toString() + { + return format("address: (%s), nic: %s, encryption: %s", + bindAddress, FBUtilities.getNetworkInterface(bindAddress.address), SocketFactory.encryptionLogStatement(encryption)); + } + + public InboundConnectionSettings withAuthenticator(IInternodeAuthenticator authenticator) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + @SuppressWarnings("unused") + public InboundConnectionSettings withBindAddress(InetAddressAndPort bindAddress) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withEncryption(ServerEncryptionOptions encryption) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, 
acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withSocketReceiveBufferSizeInBytes(int socketReceiveBufferSizeInBytes) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + @SuppressWarnings("unused") + public InboundConnectionSettings withApplicationReceiveQueueCapacityInBytes(int applicationReceiveQueueCapacityInBytes) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withAcceptMessaging(AcceptVersions acceptMessaging) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withAcceptStreaming(AcceptVersions acceptMessaging) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withSocketFactory(SocketFactory socketFactory) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public InboundConnectionSettings withHandlers(Function handlers) + { + return new InboundConnectionSettings(authenticator, bindAddress, encryption, + socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, + acceptMessaging, acceptStreaming, socketFactory, handlers); + } + + public 
InboundConnectionSettings withLegacyDefaults() + { + ServerEncryptionOptions encryption = this.encryption; + if (encryption == null) + encryption = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); + encryption = encryption.withOptional(false); + + return this.withBindAddress(bindAddress.withPort(DatabaseDescriptor.getSSLStoragePort())) + .withEncryption(encryption) + .withDefaults(); + } + + // note that connectTo is updated even if specified, in the case of pre40 messaging and using encryption (to update port) + public InboundConnectionSettings withDefaults() + { + // this is for the socket that can be plain, only ssl, or optional plain/ssl + if (bindAddress.port != DatabaseDescriptor.getStoragePort() && bindAddress.port != DatabaseDescriptor.getSSLStoragePort()) + throw new ConfigurationException(format("Local endpoint port %d doesn't match YAML configured port %d or legacy SSL port %d", + bindAddress.port, DatabaseDescriptor.getStoragePort(), DatabaseDescriptor.getSSLStoragePort())); + + IInternodeAuthenticator authenticator = this.authenticator; + ServerEncryptionOptions encryption = this.encryption; + Integer socketReceiveBufferSizeInBytes = this.socketReceiveBufferSizeInBytes; + Integer applicationReceiveQueueCapacityInBytes = this.applicationReceiveQueueCapacityInBytes; + AcceptVersions acceptMessaging = this.acceptMessaging; + AcceptVersions acceptStreaming = this.acceptStreaming; + SocketFactory socketFactory = this.socketFactory; + Function handlersFactory = this.handlers; + + if (authenticator == null) + authenticator = DatabaseDescriptor.getInternodeAuthenticator(); + + if (encryption == null) + encryption = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); + + if (socketReceiveBufferSizeInBytes == null) + socketReceiveBufferSizeInBytes = DatabaseDescriptor.getInternodeSocketReceiveBufferSizeInBytes(); + + if (applicationReceiveQueueCapacityInBytes == null) + applicationReceiveQueueCapacityInBytes = 
DatabaseDescriptor.getInternodeApplicationReceiveQueueCapacityInBytes(); + + if (acceptMessaging == null) + acceptMessaging = accept_messaging; + + if (acceptStreaming == null) + acceptStreaming = accept_streaming; + + if (socketFactory == null) + socketFactory = instance().socketFactory; + + if (handlersFactory == null) + handlersFactory = instance()::getInbound; + + Preconditions.checkArgument(socketReceiveBufferSizeInBytes == 0 || socketReceiveBufferSizeInBytes >= 1 << 10, "illegal socket send buffer size: " + socketReceiveBufferSizeInBytes); + Preconditions.checkArgument(applicationReceiveQueueCapacityInBytes >= 1 << 10, "illegal application receive queue capacity: " + applicationReceiveQueueCapacityInBytes); + + return new InboundConnectionSettings(authenticator, bindAddress, encryption, socketReceiveBufferSizeInBytes, applicationReceiveQueueCapacityInBytes, acceptMessaging, acceptStreaming, socketFactory, handlersFactory); + } +} diff --git a/src/java/org/apache/cassandra/net/InboundCounters.java b/src/java/org/apache/cassandra/net/InboundCounters.java new file mode 100644 index 000000000000..da035f226d8e --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundCounters.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * Lifetime (for the duration of host uptime) counters for inbound messages,
 * aggregated per (from, connection type).
 *
 * Implemented as volatile longs driven by {@link AtomicLongFieldUpdater}s to keep
 * the per-instance footprint small. If contention/false sharing ever become a
 * problem, consider introducing padding.
 */
class InboundCounters
{
    private volatile long errorCount;
    private volatile long errorBytes;
    private volatile long expiredCount;
    private volatile long expiredBytes;
    private volatile long processedCount;
    private volatile long processedBytes;
    private volatile long scheduledCount;
    private volatile long scheduledBytes;

    private static final AtomicLongFieldUpdater<InboundCounters> ERROR_COUNT =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "errorCount");
    private static final AtomicLongFieldUpdater<InboundCounters> ERROR_BYTES =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "errorBytes");
    private static final AtomicLongFieldUpdater<InboundCounters> EXPIRED_COUNT =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "expiredCount");
    private static final AtomicLongFieldUpdater<InboundCounters> EXPIRED_BYTES =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "expiredBytes");
    private static final AtomicLongFieldUpdater<InboundCounters> PROCESSED_COUNT =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "processedCount");
    private static final AtomicLongFieldUpdater<InboundCounters> PROCESSED_BYTES =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "processedBytes");
    private static final AtomicLongFieldUpdater<InboundCounters> SCHEDULED_COUNT =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "scheduledCount");
    private static final AtomicLongFieldUpdater<InboundCounters> SCHEDULED_BYTES =
        AtomicLongFieldUpdater.newUpdater(InboundCounters.class, "scheduledBytes");

    /** record one message that failed with an error, of the given serialized size */
    void addError(int bytes)
    {
        ERROR_COUNT.incrementAndGet(this);
        ERROR_BYTES.addAndGet(this, bytes);
    }

    long errorCount()
    {
        return errorCount;
    }

    long errorBytes()
    {
        return errorBytes;
    }

    /** record one message that arrived after its expiration */
    void addExpired(int bytes)
    {
        EXPIRED_COUNT.incrementAndGet(this);
        EXPIRED_BYTES.addAndGet(this, bytes);
    }

    long expiredCount()
    {
        return expiredCount;
    }

    long expiredBytes()
    {
        return expiredBytes;
    }

    /** record one successfully processed message */
    void addProcessed(int bytes)
    {
        PROCESSED_COUNT.incrementAndGet(this);
        PROCESSED_BYTES.addAndGet(this, bytes);
    }

    long processedCount()
    {
        return processedCount;
    }

    long processedBytes()
    {
        return processedBytes;
    }

    /** record one message entering the pending/scheduled state */
    void addPending(int bytes)
    {
        SCHEDULED_COUNT.incrementAndGet(this);
        SCHEDULED_BYTES.addAndGet(this, bytes);
    }

    /** reverse of {@link #addPending(int)}: the message has left the pending state */
    void removePending(int bytes)
    {
        SCHEDULED_COUNT.decrementAndGet(this);
        SCHEDULED_BYTES.addAndGet(this, -bytes);
    }

    long scheduledCount()
    {
        return scheduledCount;
    }

    long scheduledBytes()
    {
        return scheduledBytes;
    }
}
/**
 * Encapsulates the callbacks that {@link InboundMessageHandler} invokes during the lifecycle of an inbound message
 * passing through it: from arrival to dispatch to execution.
 *
 * The flow will vary slightly between small and large messages. Small messages will be deserialized first and only
 * then dispatched to one of the {@link Stage} stages for execution, whereas a large message will be dispatched first,
 * and deserialized in-place on the relevant stage before being immediately processed.
 *
 * This difference only shows in case of deserialization failure. For large messages, it's possible for
 * {@link #onFailedDeserialize(int, Header, Throwable)} to be invoked after {@link #onExecuting(int, Header, long, TimeUnit)},
 * whereas for small messages it isn't.
 */
interface InboundMessageCallbacks
{
    /**
     * Invoked once the header of a message has arrived, small or large.
     */
    void onHeaderArrived(int messageSize, Header header, long timeElapsed, TimeUnit unit);

    /**
     * Invoked once an entire message worth of bytes has arrived, small or large.
     */
    void onArrived(int messageSize, Header header, long timeElapsed, TimeUnit unit);

    /**
     * Invoked if a message arrived too late to be processed, after its expiration. {@code wasCorrupt} might
     * be set to {@code true} if 1+ corrupt frames were encountered while assembling an expired large message.
     */
    void onArrivedExpired(int messageSize, Header header, boolean wasCorrupt, long timeElapsed, TimeUnit unit);

    /**
     * Invoked if a large message arrived in time, but had one or more of its frames corrupted in flight.
     */
    void onArrivedCorrupt(int messageSize, Header header, long timeElapsed, TimeUnit unit);

    /**
     * Invoked if {@link InboundMessageHandler} was closed before receiving all frames of a large message.
     * {@code wasCorrupt} will be set to {@code true} if some corrupt frames had already been encountered,
     * {@code wasExpired} will be set to {@code true} if the message had expired in flight.
     */
    void onClosedBeforeArrival(int messageSize, Header header, int bytesReceived, boolean wasCorrupt, boolean wasExpired);

    /**
     * Invoked if a deserializer threw an exception while attempting to deserialize a message.
     */
    void onFailedDeserialize(int messageSize, Header header, Throwable t);

    /**
     * Invoked just before a message-processing task is scheduled on the appropriate {@link Stage}
     * for the {@link Verb} of the message.
     */
    void onDispatched(int messageSize, Header header);

    /**
     * Invoked at the very beginning of execution of the message-processing task on the appropriate {@link Stage}.
     */
    void onExecuting(int messageSize, Header header, long timeElapsed, TimeUnit unit);

    /**
     * Invoked upon 'successful' processing of the message. Alternatively, {@link #onExpired(int, Header, long, TimeUnit)}
     * will be invoked if the message had expired while waiting to be processed in the queue of the {@link Stage}.
     */
    void onProcessed(int messageSize, Header header);

    /**
     * Invoked if the message had expired while waiting to be processed in the queue of the {@link Stage}. Otherwise,
     * {@link #onProcessed(int, Header)} will be invoked.
     */
    void onExpired(int messageSize, Header header, long timeElapsed, TimeUnit unit);

    /**
     * Invoked at the very end of execution of the message-processing task, no matter the outcome of processing.
     */
    void onExecuted(int messageSize, Header header, long timeElapsed, TimeUnit unit);
}
+ */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.Consumer; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.EventLoop; +import org.apache.cassandra.concurrent.ExecutorLocals; +import org.apache.cassandra.concurrent.Stage; +import org.apache.cassandra.concurrent.StageManager; +import org.apache.cassandra.exceptions.IncompatibleSchemaException; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message.Header; +import org.apache.cassandra.net.FrameDecoder.Frame; +import org.apache.cassandra.net.FrameDecoder.FrameProcessor; +import org.apache.cassandra.net.FrameDecoder.IntactFrame; +import org.apache.cassandra.net.FrameDecoder.CorruptFrame; +import org.apache.cassandra.net.ResourceLimits.Limit; +import org.apache.cassandra.tracing.TraceState; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.JVMStabilityInspector; +import org.apache.cassandra.utils.NoSpamLogger; + +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.net.Crc.*; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +/** + * Core logic for handling inbound message deserialization and execution (in tandem with {@link FrameDecoder}). 
+ * + * Handles small and large messages, corruption, flow control, dispatch of message processing onto an appropriate + * thread pool. + * + * # Interaction with {@link FrameDecoder} + * + * {@link InboundMessageHandler} sits on top of a {@link FrameDecoder} in the Netty pipeline, and is tightly + * coupled with it. + * + * {@link FrameDecoder} decodes inbound frames and relies on a supplied {@link FrameProcessor} to act on them. + * {@link InboundMessageHandler} provides two implementations of that interface: + * - {@link #process(Frame)} is the default, primary processor, and the primary entry point to this class + * - {@link UpToOneMessageFrameProcessor}, supplied to the decoder when the handler is reactivated after being + * put in waiting mode due to lack of acquirable reserve memory capacity permits + * + * Return value of {@link FrameProcessor#process(Frame)} determines whether the decoder should keep processing + * frames (if {@code true} is returned) or stop until explicitly reactivated (if {@code false} is). To reactivate + * the decoder (once notified of available resource permits), {@link FrameDecoder#reactivate()} is invoked. + * + * # Frames + * + * {@link InboundMessageHandler} operates on frames of messages, and there are several kinds of them: + * 1. {@link IntactFrame} that are contained. As names suggest, these contain one or multiple fully contained + * messages believed to be uncorrupted. Guaranteed to not contain an part of an incomplete message. + * See {@link #processFrameOfContainedMessages(ShareableBytes, Limit, Limit)}. + * 2. {@link IntactFrame} that are NOT contained. These are uncorrupted parts of a large message split over multiple + * parts due to their size. Can represent first or subsequent frame of a large message. + * See {@link #processFirstFrameOfLargeMessage(IntactFrame, Limit, Limit)} and + * {@link #processSubsequentFrameOfLargeMessage(Frame)}. + * 3. {@link CorruptFrame} with corrupt header. 
These are unrecoverable, and force a connection to be dropped. + * 4. {@link CorruptFrame} with a valid header, but corrupt payload. These can be either contained or uncontained. + * - contained frames with corrupt payload can be gracefully dropped without dropping the connection + * - uncontained frames with corrupt payload can be gracefully dropped unless they represent the first + * frame of a new large message, as in that case we don't know how many bytes to skip + * See {@link #processCorruptFrame(CorruptFrame)}. + * + * Fundamental frame invariants: + * 1. A contained frame can only have fully-encapsulated messages - 1 to n, that don't cross frame boundaries + * 2. An uncontained frame can hold a part of one message only. It can NOT, say, contain end of one large message + * and a beginning of another one. All the bytes in an uncontained frame always belong to a single message. + * + * # Small vs large messages + * + * A single handler is equipped to process both small and large messages, potentially interleaved, but the logic + * differs depending on size. Small messages are deserialized in place, and then handed off to an appropriate + * thread pool for processing. Large messages accumulate frames until completion of a message, then hand off + * the untouched frames to the correct thread pool for the verb to be deserialized there and immediately processed. + * + * See {@link LargeMessage} for details of the large-message accumulating state-machine, and {@link ProcessMessage} + * and its inheritors for the differences in execution. + * + * # Flow control (backpressure) + * + * To prevent nodes from overwhelming and bringing each other to the knees with more inbound messages that + * can be processed in a timely manner, {@link InboundMessageHandler} implements a strict flow control policy. + * + * Before we attempt to process a message fully, we first infer its size from the stream. Then we attempt to + * acquire memory permits for a message of that size. 
If we succeed, then we move on actually process the message. + * If we fail, the frame decoder deactivates until sufficient permits are released for the message to be processed + * and the handler is activated again. Permits are released back once the message has been fully processed - + * after the verb handler has been invoked - on the {@link Stage} for the {@link Verb} of the message. + * + * Every connection has an exclusive number of permits allocated to it (by default 4MiB). In addition to it, + * there is a per-endpoint reserve capacity and a global reserve capacity {@link Limit}, shared between all + * connections from the same host and all connections, respectively. So long as long as the handler stays within + * its exclusive limit, it doesn't need to tap into reserve capacity. + * + * If tapping into reserve capacity is necessary, but the handler fails to acquire capacity from either + * endpoint of global reserve (and it needs to acquire from both), the handler and its frame decoder become + * inactive and register with a {@link WaitQueue} of the appropriate type, depending on which of the reserves + * couldn't be tapped into. Once enough messages have finished processing and had their permits released back + * to the reserves, {@link WaitQueue} will reactivate the sleeping handlers and they'll resume processing frames. + * + * The reason we 'split' reserve capacity into two limits - endpoing and global - is to guarantee liveness, and + * prevent single endpoint's connections from taking over the whole reserve, starving other connections. + * + * One permit per byte of serialized message gets acquired. When inflated on-heap, each message will occupy more + * than that, necessarily, but despite wide variance, it's a good enough proxy that correlates with on-heap footprint. 
+ */ +public class InboundMessageHandler extends ChannelInboundHandlerAdapter implements FrameProcessor +{ + private static final Logger logger = LoggerFactory.getLogger(InboundMessageHandler.class); + private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(logger, 1L, TimeUnit.SECONDS); + + private static final Message.Serializer serializer = Message.serializer; + + private final FrameDecoder decoder; + + private final ConnectionType type; + private final Channel channel; + private final InetAddressAndPort self; + private final InetAddressAndPort peer; + private final int version; + + private final int largeThreshold; + private LargeMessage largeMessage; + + private final long queueCapacity; + volatile long queueSize = 0L; + private static final AtomicLongFieldUpdater queueSizeUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandler.class, "queueSize"); + + private final Limit endpointReserveCapacity; + private final WaitQueue endpointWaitQueue; + + private final Limit globalReserveCapacity; + private final WaitQueue globalWaitQueue; + + private final OnHandlerClosed onClosed; + private final InboundMessageCallbacks callbacks; + private final Consumer> consumer; + + // wait queue handle, non-null if we overrun endpoint or global capacity and request to be resumed once it's released + private WaitQueue.Ticket ticket = null; + + long corruptFramesRecovered, corruptFramesUnrecovered; + long receivedCount, receivedBytes; + long throttledCount, throttledNanos; + + private boolean isClosed; + + InboundMessageHandler(FrameDecoder decoder, + + ConnectionType type, + Channel channel, + InetAddressAndPort self, + InetAddressAndPort peer, + int version, + int largeThreshold, + + long queueCapacity, + Limit endpointReserveCapacity, + Limit globalReserveCapacity, + WaitQueue endpointWaitQueue, + WaitQueue globalWaitQueue, + + OnHandlerClosed onClosed, + InboundMessageCallbacks callbacks, + Consumer> consumer) + { + this.decoder = decoder; + + 
this.type = type; + this.channel = channel; + this.self = self; + this.peer = peer; + this.version = version; + this.largeThreshold = largeThreshold; + + this.queueCapacity = queueCapacity; + this.endpointReserveCapacity = endpointReserveCapacity; + this.endpointWaitQueue = endpointWaitQueue; + this.globalReserveCapacity = globalReserveCapacity; + this.globalWaitQueue = globalWaitQueue; + + this.onClosed = onClosed; + this.callbacks = callbacks; + this.consumer = consumer; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + { + /* + * InboundMessageHandler works in tandem with FrameDecoder to implement flow control + * and work stashing optimally. We rely on FrameDecoder to invoke the provided + * FrameProcessor rather than on the pipeline and invocations of channelRead(). + * process(Frame) is the primary entry point for this class. + */ + throw new IllegalStateException("InboundMessageHandler doesn't expect channelRead() to be invoked"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) + { + decoder.activate(this); // the frame decoder starts inactive until explicitly activated by the added inbound message handler + } + + @Override + public boolean process(Frame frame) throws IOException + { + if (frame instanceof IntactFrame) + return processIntactFrame((IntactFrame) frame, endpointReserveCapacity, globalReserveCapacity); + + processCorruptFrame((CorruptFrame) frame); + return true; + } + + private boolean processIntactFrame(IntactFrame frame, Limit endpointReserve, Limit globalReserve) throws IOException + { + if (frame.isSelfContained) + return processFrameOfContainedMessages(frame.contents, endpointReserve, globalReserve); + else if (null == largeMessage) + return processFirstFrameOfLargeMessage(frame, endpointReserve, globalReserve); + else + return processSubsequentFrameOfLargeMessage(frame); + } + + /* + * Handle contained messages (not crossing boundaries of the frame) - both small and large, for the 
inbound + * definition of large (breaching the size threshold for what we are willing to process on event-loop vs. + * off event-loop). + */ + private boolean processFrameOfContainedMessages(ShareableBytes bytes, Limit endpointReserve, Limit globalReserve) throws IOException + { + while (bytes.hasRemaining()) + if (!processOneContainedMessage(bytes, endpointReserve, globalReserve)) + return false; + return true; + } + + private boolean processOneContainedMessage(ShareableBytes bytes, Limit endpointReserve, Limit globalReserve) throws IOException + { + ByteBuffer buf = bytes.get(); + + long currentTimeNanos = approxTime.now(); + Header header = serializer.extractHeader(buf, peer, currentTimeNanos, version); + long timeElapsed = currentTimeNanos - header.createdAtNanos; + int size = serializer.inferMessageSize(buf, buf.position(), buf.limit(), version); + + if (approxTime.isAfter(currentTimeNanos, header.expiresAtNanos)) + { + callbacks.onHeaderArrived(size, header, timeElapsed, NANOSECONDS); + callbacks.onArrivedExpired(size, header, false, timeElapsed, NANOSECONDS); + receivedCount++; + receivedBytes += size; + bytes.skipBytes(size); + return true; + } + + if (!acquireCapacity(endpointReserve, globalReserve, size, currentTimeNanos, header.expiresAtNanos)) + return false; + + callbacks.onHeaderArrived(size, header, timeElapsed, NANOSECONDS); + callbacks.onArrived(size, header, timeElapsed, NANOSECONDS); + receivedCount++; + receivedBytes += size; + + if (size <= largeThreshold) + processSmallMessage(bytes, size, header); + else + processLargeMessage(bytes, size, header); + + return true; + } + + private void processSmallMessage(ShareableBytes bytes, int size, Header header) + { + ByteBuffer buf = bytes.get(); + final int begin = buf.position(); + final int end = buf.limit(); + buf.limit(begin + size); // cap to expected message size + + Message message = null; + try (DataInputBuffer in = new DataInputBuffer(buf, false)) + { + Message m = serializer.deserialize(in, 
header, version); + if (in.available() > 0) // bytes remaining after deser: deserializer is busted + throw new InvalidSerializedSizeException(header.verb, size, size - in.available()); + message = m; + } + catch (IncompatibleSchemaException e) + { + callbacks.onFailedDeserialize(size, header, e); + noSpamLogger.info("{} incompatible schema encountered while deserializing a message", id(), e); + } + catch (Throwable t) + { + JVMStabilityInspector.inspectThrowable(t, false); + callbacks.onFailedDeserialize(size, header, t); + logger.error("{} unexpected exception caught while deserializing a message", id(), t); + } + finally + { + if (null == message) + releaseCapacity(size); + + // no matter what, set position to the beginning of the next message and restore limit, so that + // we can always keep on decoding the frame even on failure to deserialize previous message + buf.position(begin + size); + buf.limit(end); + } + + if (null != message) + dispatch(new ProcessSmallMessage(message, size)); + } + + // for various reasons, it's possible for a large message to be contained in a single frame + private void processLargeMessage(ShareableBytes bytes, int size, Header header) + { + new LargeMessage(size, header, bytes.sliceAndConsume(size).share()).schedule(); + } + + /* + * Handling of multi-frame large messages + */ + + private boolean processFirstFrameOfLargeMessage(IntactFrame frame, Limit endpointReserve, Limit globalReserve) throws IOException + { + ShareableBytes bytes = frame.contents; + ByteBuffer buf = bytes.get(); + + long currentTimeNanos = approxTime.now(); + Header header = serializer.extractHeader(buf, peer, currentTimeNanos, version); + int size = serializer.inferMessageSize(buf, buf.position(), buf.limit(), version); + + boolean expired = approxTime.isAfter(currentTimeNanos, header.expiresAtNanos); + if (!expired && !acquireCapacity(endpointReserve, globalReserve, size, currentTimeNanos, header.expiresAtNanos)) + return false; + + 
callbacks.onHeaderArrived(size, header, currentTimeNanos - header.createdAtNanos, NANOSECONDS); + receivedBytes += buf.remaining(); + largeMessage = new LargeMessage(size, header, expired); + largeMessage.supply(frame); + return true; + } + + private boolean processSubsequentFrameOfLargeMessage(Frame frame) + { + receivedBytes += frame.frameSize; + if (largeMessage.supply(frame)) + { + receivedCount++; + largeMessage = null; + } + return true; + } + + /* + * We can handle some corrupt frames gracefully without dropping the connection and losing all the + * queued up messages, but not others. + * + * Corrupt frames that *ARE NOT* safe to skip gracefully and require the connection to be dropped: + * - any frame with corrupt header (!frame.isRecoverable()) + * - first corrupt-payload frame of a large message (impossible to infer message size, and without it + * impossible to skip the message safely + * + * Corrupt frames that *ARE* safe to skip gracefully, without reconnecting: + * - any self-contained frame with a corrupt payload (but not header): we lose all the messages in the + * frame, but that has no effect on subsequent ones + * - any non-first payload-corrupt frame of a large message: we know the size of the large message in + * flight, so we just skip frames until we've seen all its bytes; we only lose the large message + */ + private void processCorruptFrame(CorruptFrame frame) throws InvalidCrc + { + if (!frame.isRecoverable()) + { + corruptFramesUnrecovered++; + throw new InvalidCrc(frame.readCRC, frame.computedCRC); + } + else if (frame.isSelfContained) + { + receivedBytes += frame.frameSize; + corruptFramesRecovered++; + noSpamLogger.warn("{} invalid, recoverable CRC mismatch detected while reading messages (corrupted self-contained frame)", id()); + } + else if (null == largeMessage) // first frame of a large message + { + receivedBytes += frame.frameSize; + corruptFramesUnrecovered++; + noSpamLogger.error("{} invalid, unrecoverable CRC mismatch 
detected while reading messages (corrupted first frame of a large message)", id()); + throw new InvalidCrc(frame.readCRC, frame.computedCRC); + } + else // subsequent frame of a large message + { + processSubsequentFrameOfLargeMessage(frame); + corruptFramesRecovered++; + noSpamLogger.warn("{} invalid, recoverable CRC mismatch detected while reading a large message", id()); + } + } + + private void onEndpointReserveCapacityRegained(Limit endpointReserve, long elapsedNanos) + { + onReserveCapacityRegained(endpointReserve, globalReserveCapacity, elapsedNanos); + } + + private void onGlobalReserveCapacityRegained(Limit globalReserve, long elapsedNanos) + { + onReserveCapacityRegained(endpointReserveCapacity, globalReserve, elapsedNanos); + } + + private void onReserveCapacityRegained(Limit endpointReserve, Limit globalReserve, long elapsedNanos) + { + if (isClosed) + return; + + assert channel.eventLoop().inEventLoop(); + + ticket = null; + throttledNanos += elapsedNanos; + + try + { + /* + * Process up to one message using supplied overridden reserves - one of them pre-allocated, + * and guaranteed to be enough for one message - then, if no obstacles encountered, reactivate + * the frame decoder using normal reserve capacities. + */ + if (processUpToOneMessage(endpointReserve, globalReserve)) + decoder.reactivate(); + } + catch (Throwable t) + { + exceptionCaught(t); + } + } + + // return true if the handler should be reactivated - if no new hurdles were encountered, + // like running out of the other kind of reserve capacity + private boolean processUpToOneMessage(Limit endpointReserve, Limit globalReserve) throws IOException + { + UpToOneMessageFrameProcessor processor = new UpToOneMessageFrameProcessor(endpointReserve, globalReserve); + decoder.processBacklog(processor); + return processor.isActive; + } + + /* + * Process at most one message. 
Won't always be an entire one (if the message in the head of line + * is a large one, and there aren't sufficient frames to decode it entirely), but will never be more than one. + */ + private class UpToOneMessageFrameProcessor implements FrameProcessor + { + private final Limit endpointReserve; + private final Limit globalReserve; + + boolean isActive = true; + boolean firstFrame = true; + + private UpToOneMessageFrameProcessor(Limit endpointReserve, Limit globalReserve) + { + this.endpointReserve = endpointReserve; + this.globalReserve = globalReserve; + } + + @Override + public boolean process(Frame frame) throws IOException + { + if (firstFrame) + { + if (!(frame instanceof IntactFrame)) + throw new IllegalStateException("First backlog frame must be intact"); + firstFrame = false; + return processFirstFrame((IntactFrame) frame); + } + + return processSubsequentFrame(frame); + } + + private boolean processFirstFrame(IntactFrame frame) throws IOException + { + if (frame.isSelfContained) + { + isActive = processOneContainedMessage(frame.contents, endpointReserve, globalReserve); + return false; // stop after one message + } + else + { + isActive = processFirstFrameOfLargeMessage(frame, endpointReserve, globalReserve); + return isActive; // continue unless fallen behind coprocessor or ran out of reserve capacity again + } + } + + private boolean processSubsequentFrame(Frame frame) throws IOException + { + if (frame instanceof IntactFrame) + processSubsequentFrameOfLargeMessage(frame); + else + processCorruptFrame((CorruptFrame) frame); + + return largeMessage != null; // continue until done with the large message + } + } + + /** + * Try to acquire permits for the inbound message. In case of failure, register with the right wait queue to be + * reactivated once permit capacity is regained. 
+ */ + @SuppressWarnings("BooleanMethodIsAlwaysInverted") + private boolean acquireCapacity(Limit endpointReserve, Limit globalReserve, int bytes, long currentTimeNanos, long expiresAtNanos) + { + ResourceLimits.Outcome outcome = acquireCapacity(endpointReserve, globalReserve, bytes); + + if (outcome == ResourceLimits.Outcome.INSUFFICIENT_ENDPOINT) + ticket = endpointWaitQueue.register(this, bytes, currentTimeNanos, expiresAtNanos); + else if (outcome == ResourceLimits.Outcome.INSUFFICIENT_GLOBAL) + ticket = globalWaitQueue.register(this, bytes, currentTimeNanos, expiresAtNanos); + + if (outcome != ResourceLimits.Outcome.SUCCESS) + throttledCount++; + + return outcome == ResourceLimits.Outcome.SUCCESS; + } + + private ResourceLimits.Outcome acquireCapacity(Limit endpointReserve, Limit globalReserve, int bytes) + { + long currentQueueSize = queueSize; + + /* + * acquireCapacity() is only ever called on the event loop, and as such queueSize is only ever increased + * on the event loop. If there is enough capacity, we can safely addAndGet() and immediately return. 
+ */ + if (currentQueueSize + bytes <= queueCapacity) + { + queueSizeUpdater.addAndGet(this, bytes); + return ResourceLimits.Outcome.SUCCESS; + } + + // we know we don't have enough local queue capacity for the entire message, so we need to borrow some from reserve capacity + long allocatedExcess = min(currentQueueSize + bytes - queueCapacity, bytes); + + if (!globalReserve.tryAllocate(allocatedExcess)) + return ResourceLimits.Outcome.INSUFFICIENT_GLOBAL; + + if (!endpointReserve.tryAllocate(allocatedExcess)) + { + globalReserve.release(allocatedExcess); + globalWaitQueue.signal(); + return ResourceLimits.Outcome.INSUFFICIENT_GLOBAL; + } + + long newQueueSize = queueSizeUpdater.addAndGet(this, bytes); + long actualExcess = max(0, min(newQueueSize - queueCapacity, bytes)); + + /* + * It's possible that some permits were released at some point after we loaded current queueSize, + * and we can satisfy more of the permits using our exclusive per-connection capacity, needing + * less than previously estimated from the reserves. If that's the case, release the now unneeded + * permit excess back to endpoint/global reserves. + */ + if (actualExcess != allocatedExcess) // actualExcess < allocatedExcess + { + long excess = allocatedExcess - actualExcess; + + endpointReserve.release(excess); + globalReserve.release(excess); + + endpointWaitQueue.signal(); + globalWaitQueue.signal(); + } + + return ResourceLimits.Outcome.SUCCESS; + } + + private void releaseCapacity(int bytes) + { + long oldQueueSize = queueSizeUpdater.getAndAdd(this, -bytes); + if (oldQueueSize > queueCapacity) + { + long excess = min(oldQueueSize - queueCapacity, bytes); + + endpointReserveCapacity.release(excess); + globalReserveCapacity.release(excess); + + endpointWaitQueue.signal(); + globalWaitQueue.signal(); + } + } + + /** + * Invoked to release capacity for a message that has been fully, successfully processed. 
+ * + * Normally no different from invoking {@link #releaseCapacity(int)}, but is necessary for the verifier + * to be able to delay capacity release for backpressure testing. + */ + @VisibleForTesting + protected void releaseProcessedCapacity(int size, Header header) + { + releaseCapacity(size); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + try + { + exceptionCaught(cause); + } + catch (Throwable t) + { + logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + } + } + + private void exceptionCaught(Throwable cause) + { + decoder.discard(); + + JVMStabilityInspector.inspectThrowable(cause, false); + + if (cause instanceof Message.InvalidLegacyProtocolMagic) + logger.error("{} invalid, unrecoverable CRC mismatch detected while reading messages - closing the connection", id()); + else + logger.error("{} unexpected exception caught while processing inbound messages; terminating connection", id(), cause); + + channel.close(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + isClosed = true; + + if (null != largeMessage) + largeMessage.abort(); + + if (null != ticket) + ticket.invalidate(); + + onClosed.call(this); + } + + private EventLoop eventLoop() + { + return channel.eventLoop(); + } + + String id(boolean includeReal) + { + if (!includeReal) + return id(); + + return SocketFactory.channelId(peer, (InetSocketAddress) channel.remoteAddress(), + self, (InetSocketAddress) channel.localAddress(), + type, channel.id().asShortText()); + } + + String id() + { + return SocketFactory.channelId(peer, self, type, channel.id().asShortText()); + } + + /* + * A large-message frame-accumulating state machine. 
+ * + * Collects intact frames until it has all the bytes necessary to deserialize the large message, + * at which point it schedules a task on the appropriate {@link Stage}, + * a task that deserializes the message and immediately invokes the verb handler. + * + * Also handles corrupt frames and potential expiry of the large message during accumulation: + * if it's taking the frames too long to arrive, there is no point in holding on to the + * accumulated frames, or in gathering more - so we release the ones we already have, and + * skip any remaining ones, along with returning memory permits early. + */ + private class LargeMessage + { + private final int size; + private final Header header; + + private final List buffers = new ArrayList<>(); + private int received; + + private boolean isExpired; + private boolean isCorrupt; + + private LargeMessage(int size, Header header, boolean isExpired) + { + this.size = size; + this.header = header; + this.isExpired = isExpired; + } + + private LargeMessage(int size, Header header, ShareableBytes bytes) + { + this(size, header, false); + buffers.add(bytes); + } + + private void schedule() + { + dispatch(new ProcessLargeMessage(this)); + } + + /** + * Return true if this was the last frame of the large message. 
+ */ + private boolean supply(Frame frame) + { + if (frame instanceof IntactFrame) + onIntactFrame((IntactFrame) frame); + else + onCorruptFrame(); + + received += frame.frameSize; + if (size == received) + onComplete(); + return size == received; + } + + private void onIntactFrame(IntactFrame frame) + { + boolean expires = approxTime.isAfter(header.expiresAtNanos); + if (!isExpired && !isCorrupt) + { + if (!expires) + { + buffers.add(frame.contents.sliceAndConsume(frame.frameSize).share()); + return; + } + releaseBuffersAndCapacity(); // release resources once we transition from normal state to expired + } + frame.consume(); + isExpired |= expires; + } + + private void onCorruptFrame() + { + if (!isExpired && !isCorrupt) + releaseBuffersAndCapacity(); // release resources once we transition from normal state to corrupt + isCorrupt = true; + isExpired |= approxTime.isAfter(header.expiresAtNanos); + } + + private void onComplete() + { + long timeElapsed = approxTime.now() - header.createdAtNanos; + + if (!isExpired && !isCorrupt) + { + callbacks.onArrived(size, header, timeElapsed, NANOSECONDS); + schedule(); + } + else if (isExpired) + { + callbacks.onArrivedExpired(size, header, isCorrupt, timeElapsed, NANOSECONDS); + } + else + { + callbacks.onArrivedCorrupt(size, header, timeElapsed, NANOSECONDS); + } + } + + private void abort() + { + if (!isExpired && !isCorrupt) + releaseBuffersAndCapacity(); // release resources if in normal state when abort() is invoked + callbacks.onClosedBeforeArrival(size, header, received, isCorrupt, isExpired); + } + + private void releaseBuffers() + { + buffers.forEach(ShareableBytes::release); buffers.clear(); + } + + private void releaseBuffersAndCapacity() + { + releaseBuffers(); releaseCapacity(size); + } + + private Message deserialize() + { + try (ChunkedInputPlus input = ChunkedInputPlus.of(buffers)) + { + Message m = serializer.deserialize(input, header, version); + int remainder = input.remainder(); + if (remainder > 0) + 
throw new InvalidSerializedSizeException(header.verb, size, size - remainder); + return m; + } + catch (IncompatibleSchemaException e) + { + callbacks.onFailedDeserialize(size, header, e); + noSpamLogger.info("{} incompatible schema encountered while deserializing a message", id(), e); + } + catch (Throwable t) + { + JVMStabilityInspector.inspectThrowable(t, false); + callbacks.onFailedDeserialize(size, header, t); + logger.error("{} unexpected exception caught while deserializing a message", id(), t); + } + finally + { + buffers.clear(); // closing the input will have ensured that the buffers were released no matter what + } + + return null; + } + } + + /** + * Submit a {@link ProcessMessage} task to the appropriate {@link Stage} for the {@link Verb}. + */ + private void dispatch(ProcessMessage task) + { + Header header = task.header(); + + TraceState state = Tracing.instance.initializeFromMessage(header); + if (state != null) state.trace("{} message received from {}", header.verb, header.from); + + callbacks.onDispatched(task.size(), header); + StageManager.getStage(header.verb.stage).execute(task, ExecutorLocals.create(state)); + } + + private abstract class ProcessMessage implements Runnable + { + /** + * Actually handle the message. Runs on the appropriate {@link Stage} for the {@link Verb}. + * + * Small messages will come pre-deserialized. Large messages will be deserialized on the stage, + * just in time, and only then processed. 
+ */ + public void run() + { + Header header = header(); + long currentTimeNanos = approxTime.now(); + boolean expired = approxTime.isAfter(currentTimeNanos, header.expiresAtNanos); + + boolean processed = false; + try + { + callbacks.onExecuting(size(), header, currentTimeNanos - header.createdAtNanos, NANOSECONDS); + + if (expired) + { + callbacks.onExpired(size(), header, currentTimeNanos - header.createdAtNanos, NANOSECONDS); + return; + } + + Message message = provideMessage(); + if (null != message) + { + consumer.accept(message); + processed = true; + callbacks.onProcessed(size(), header); + } + } + finally + { + if (processed) + releaseProcessedCapacity(size(), header); + else + releaseCapacity(size()); + + releaseResources(); + + callbacks.onExecuted(size(), header, approxTime.now() - currentTimeNanos, NANOSECONDS); + } + } + + abstract int size(); + abstract Header header(); + abstract Message provideMessage(); + void releaseResources() {} + } + + private class ProcessSmallMessage extends ProcessMessage + { + private final int size; + private final Message message; + + ProcessSmallMessage(Message message, int size) + { + this.size = size; + this.message = message; + } + + int size() + { + return size; + } + + Header header() + { + return message.header; + } + + Message provideMessage() + { + return message; + } + } + + private class ProcessLargeMessage extends ProcessMessage + { + private final LargeMessage message; + + ProcessLargeMessage(LargeMessage message) + { + this.message = message; + } + + int size() + { + return message.size; + } + + Header header() + { + return message.header; + } + + Message provideMessage() + { + return message.deserialize(); + } + + @Override + void releaseResources() + { + message.releaseBuffers(); // releases buffers if they haven't been yet (by deserialize() call) + } + } + + /** + * A special-purpose wait queue to park inbound message handlers that failed to allocate + * reserve capacity for a message in. 
Upon such failure a handler registers itself with + * a {@link WaitQueue} of the appropriate kind (either ENDPOINT or GLOBAL - if failed + * to allocate endpoint or global reserve capacity, respectively), stops processing any + * accumulated frames or receiving new ones, and waits - until reactivated. + * + * Every time permits are returned to an endpoint or global {@link Limit}, the respective + * queue gets signalled, and if there are any handlers registered in it, we will attempt + * to reactivate as many waiting handlers as current available reserve capacity allows + * us to - immediately, on the {@link #signal()}-calling thread. At most one such attempt + * will be in progress at any given time. + * + * Handlers that can be reactivated will be grouped by their {@link EventLoop} and a single + * {@link ReactivateHandlers} task will be scheduled per event loop, on the corresponding + * event loops. + * + * When run, the {@link ReactivateHandlers} task will ask each handler in its group to first + * process one message - using preallocated reserve capacity - and if no obstacles were met - + * reactivate the handlers, this time using their regular reserves. + * + * See {@link WaitQueue#schedule()}, {@link ReactivateHandlers#run()}, {@link Ticket#reactivateHandler(Limit)}. 
+ */ + public static final class WaitQueue + { + enum Kind { ENDPOINT, GLOBAL } + + private static final int NOT_RUNNING = 0; + @SuppressWarnings("unused") + private static final int RUNNING = 1; + private static final int RUN_AGAIN = 2; + + private volatile int scheduled; + private static final AtomicIntegerFieldUpdater scheduledUpdater = + AtomicIntegerFieldUpdater.newUpdater(WaitQueue.class, "scheduled"); + + private final Kind kind; + private final Limit reserveCapacity; + + private final ManyToOneConcurrentLinkedQueue queue = new ManyToOneConcurrentLinkedQueue<>(); + + private WaitQueue(Kind kind, Limit reserveCapacity) + { + this.kind = kind; + this.reserveCapacity = reserveCapacity; + } + + public static WaitQueue endpoint(Limit endpointReserveCapacity) + { + return new WaitQueue(Kind.ENDPOINT, endpointReserveCapacity); + } + + public static WaitQueue global(Limit globalReserveCapacity) + { + return new WaitQueue(Kind.GLOBAL, globalReserveCapacity); + } + + private Ticket register(InboundMessageHandler handler, int bytesRequested, long registeredAtNanos, long expiresAtNanos) + { + Ticket ticket = new Ticket(this, handler, bytesRequested, registeredAtNanos, expiresAtNanos); + Ticket previous = queue.relaxedPeekLastAndOffer(ticket); + if (null == previous || !previous.isWaiting()) + signal(); // only signal the queue if this handler is first to register + return ticket; + } + + private void signal() + { + if (queue.relaxedIsEmpty()) + return; // we can return early if no handlers have registered with the wait queue + + if (NOT_RUNNING == scheduledUpdater.getAndUpdate(this, i -> min(RUN_AGAIN, i + 1))) + { + do + { + schedule(); + } + while (RUN_AGAIN == scheduledUpdater.getAndDecrement(this)); + } + } + + private void schedule() + { + Map tasks = null; + + long currentTimeNanos = approxTime.now(); + + Ticket t; + while ((t = queue.peek()) != null) + { + if (!t.call()) // invalidated + { + queue.remove(); + continue; + } + + boolean isLive = 
t.isLive(currentTimeNanos); + if (isLive && !reserveCapacity.tryAllocate(t.bytesRequested)) + { + if (!t.reset()) // the ticket was invalidated after being called but before now + { + queue.remove(); + continue; + } + break; // TODO: traverse the entire queue to unblock handlers that have expired or invalidated tickets + } + + if (null == tasks) + tasks = new IdentityHashMap<>(); + + queue.remove(); + tasks.computeIfAbsent(t.handler.eventLoop(), e -> new ReactivateHandlers()).add(t, isLive); + } + + if (null != tasks) + tasks.forEach(EventLoop::execute); + } + + private class ReactivateHandlers implements Runnable + { + List tickets = new ArrayList<>(); + long capacity = 0L; + + private void add(Ticket ticket, boolean isLive) + { + tickets.add(ticket); + if (isLive) capacity += ticket.bytesRequested; + } + + public void run() + { + Limit limit = new ResourceLimits.Basic(capacity); + try + { + for (Ticket ticket : tickets) + ticket.reactivateHandler(limit); + } + finally + { + /* + * Free up any unused capacity, if any. Will be non-zero if one or more handlers were closed + * when we attempted to run their callback, or used more of their other reserve; or if the first + * message in the unprocessed stream has expired in the narrow time window. 
+ */ + long remaining = limit.remaining(); + if (remaining > 0) + { + reserveCapacity.release(remaining); + signal(); + } + } + } + + private static final class Ticket + { + private static final int WAITING = 0; + private static final int CALLED = 1; + private static final int INVALIDATED = 2; // invalidated by a handler that got closed + + private volatile int state; + private static final AtomicIntegerFieldUpdater stateUpdater = + AtomicIntegerFieldUpdater.newUpdater(Ticket.class, "state"); + + private final WaitQueue waitQueue; + private final InboundMessageHandler handler; + private final int bytesRequested; + private final long registeredAtNanos; + private final long expiresAtNanos; + + private Ticket(WaitQueue waitQueue, InboundMessageHandler handler, int bytesRequested, long registeredAtNanos, long expiresAtNanos) + { + this.waitQueue = waitQueue; + this.handler = handler; + this.bytesRequested = bytesRequested; + this.registeredAtNanos = registeredAtNanos; + this.expiresAtNanos = expiresAtNanos; + } + + private void reactivateHandler(Limit capacity) + { + long elapsedNanos = approxTime.now() - registeredAtNanos; + try + { + if (waitQueue.kind == Kind.ENDPOINT) + handler.onEndpointReserveCapacityRegained(capacity, elapsedNanos); + else + handler.onGlobalReserveCapacityRegained(capacity, elapsedNanos); + } + catch (Throwable t) + { + logger.error("{} exception caught while reactivating a handler", handler.id(), t); + } + } + + private boolean isWaiting() + { + return state == WAITING; + } + + private boolean isLive(long currentTimeNanos) + { + return !approxTime.isAfter(currentTimeNanos, expiresAtNanos); + } + + private void invalidate() + { + state = INVALIDATED; + waitQueue.signal(); + } + + private boolean call() + { + return stateUpdater.compareAndSet(this, WAITING, CALLED); + } + + private boolean reset() + { + return stateUpdater.compareAndSet(this, CALLED, WAITING); + } + } + } + + public interface OnHandlerClosed + { + void 
call(InboundMessageHandler handler); + } +} diff --git a/src/java/org/apache/cassandra/net/InboundMessageHandlers.java b/src/java/org/apache/cassandra/net/InboundMessageHandlers.java new file mode 100644 index 000000000000..4ebd5ad76afb --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundMessageHandlers.java @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.Collection; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.Consumer; +import java.util.function.ToLongFunction; + +import io.netty.channel.Channel; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.metrics.InternodeInboundMetrics; +import org.apache.cassandra.net.Message.Header; +import org.apache.cassandra.utils.ApproximateTime; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +/** + * An aggregation of {@link InboundMessageHandler}s for all connections from a peer. + * + * Manages metrics and shared resource limits. 
Can have multiple connections of a single + * type open simultaneously (legacy in particular). + */ +public final class InboundMessageHandlers +{ + private final InetAddressAndPort self; + private final InetAddressAndPort peer; + + private final int queueCapacity; + private final ResourceLimits.Limit endpointReserveCapacity; + private final ResourceLimits.Limit globalReserveCapacity; + + private final InboundMessageHandler.WaitQueue endpointWaitQueue; + private final InboundMessageHandler.WaitQueue globalWaitQueue; + + private final InboundCounters urgentCounters = new InboundCounters(); + private final InboundCounters smallCounters = new InboundCounters(); + private final InboundCounters largeCounters = new InboundCounters(); + private final InboundCounters legacyCounters = new InboundCounters(); + + private final InboundMessageCallbacks urgentCallbacks; + private final InboundMessageCallbacks smallCallbacks; + private final InboundMessageCallbacks largeCallbacks; + private final InboundMessageCallbacks legacyCallbacks; + + private final InternodeInboundMetrics metrics; + private final MessageConsumer messageConsumer; + + private final HandlerProvider handlerProvider; + private final Collection handlers = new CopyOnWriteArrayList<>(); + + static class GlobalResourceLimits + { + final ResourceLimits.Limit reserveCapacity; + final InboundMessageHandler.WaitQueue waitQueue; + + GlobalResourceLimits(ResourceLimits.Limit reserveCapacity) + { + this.reserveCapacity = reserveCapacity; + this.waitQueue = InboundMessageHandler.WaitQueue.global(reserveCapacity); + } + } + + public interface MessageConsumer extends Consumer> + { + void fail(Message.Header header, Throwable failure); + } + + public interface GlobalMetricCallbacks + { + LatencyConsumer internodeLatencyRecorder(InetAddressAndPort to); + void recordInternalLatency(Verb verb, long timeElapsed, TimeUnit timeUnit); + void recordInternodeDroppedMessage(Verb verb, long timeElapsed, TimeUnit timeUnit); + } + + public 
InboundMessageHandlers(InetAddressAndPort self, + InetAddressAndPort peer, + int queueCapacity, + long endpointReserveCapacity, + GlobalResourceLimits globalResourceLimits, + GlobalMetricCallbacks globalMetricCallbacks, + MessageConsumer messageConsumer) + { + this(self, peer, queueCapacity, endpointReserveCapacity, globalResourceLimits, globalMetricCallbacks, messageConsumer, InboundMessageHandler::new); + } + + public InboundMessageHandlers(InetAddressAndPort self, + InetAddressAndPort peer, + int queueCapacity, + long endpointReserveCapacity, + GlobalResourceLimits globalResourceLimits, + GlobalMetricCallbacks globalMetricCallbacks, + MessageConsumer messageConsumer, + HandlerProvider handlerProvider) + { + this.self = self; + this.peer = peer; + + this.queueCapacity = queueCapacity; + this.endpointReserveCapacity = new ResourceLimits.Concurrent(endpointReserveCapacity); + this.globalReserveCapacity = globalResourceLimits.reserveCapacity; + this.endpointWaitQueue = InboundMessageHandler.WaitQueue.endpoint(this.endpointReserveCapacity); + this.globalWaitQueue = globalResourceLimits.waitQueue; + this.messageConsumer = messageConsumer; + + this.handlerProvider = handlerProvider; + + urgentCallbacks = makeMessageCallbacks(peer, urgentCounters, globalMetricCallbacks, messageConsumer); + smallCallbacks = makeMessageCallbacks(peer, smallCounters, globalMetricCallbacks, messageConsumer); + largeCallbacks = makeMessageCallbacks(peer, largeCounters, globalMetricCallbacks, messageConsumer); + legacyCallbacks = makeMessageCallbacks(peer, legacyCounters, globalMetricCallbacks, messageConsumer); + + metrics = new InternodeInboundMetrics(peer, this); + } + + InboundMessageHandler createHandler(FrameDecoder frameDecoder, ConnectionType type, Channel channel, int version) + { + InboundMessageHandler handler = + handlerProvider.provide(frameDecoder, + + type, + channel, + self, + peer, + version, + OutboundConnections.LARGE_MESSAGE_THRESHOLD, + + queueCapacity, + 
endpointReserveCapacity, + globalReserveCapacity, + endpointWaitQueue, + globalWaitQueue, + + this::onHandlerClosed, + callbacksFor(type), + messageConsumer); + handlers.add(handler); + return handler; + } + + void releaseMetrics() + { + metrics.release(); + } + + private void onHandlerClosed(InboundMessageHandler handler) + { + handlers.remove(handler); + absorbCounters(handler); + } + + /* + * Message callbacks + */ + + private InboundMessageCallbacks callbacksFor(ConnectionType type) + { + switch (type) + { + case URGENT_MESSAGES: return urgentCallbacks; + case SMALL_MESSAGES: return smallCallbacks; + case LARGE_MESSAGES: return largeCallbacks; + case LEGACY_MESSAGES: return legacyCallbacks; + } + + throw new IllegalArgumentException(); + } + + private static InboundMessageCallbacks makeMessageCallbacks(InetAddressAndPort peer, InboundCounters counters, GlobalMetricCallbacks globalMetrics, MessageConsumer messageConsumer) + { + LatencyConsumer internodeLatency = globalMetrics.internodeLatencyRecorder(peer); + + return new InboundMessageCallbacks() + { + @Override + public void onHeaderArrived(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + // do not log latency if we are within error bars of zero + if (timeElapsed > unit.convert(approxTime.error(), NANOSECONDS)) + internodeLatency.accept(timeElapsed, unit); + } + + @Override + public void onArrived(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + } + + @Override + public void onArrivedExpired(int messageSize, Header header, boolean wasCorrupt, long timeElapsed, TimeUnit unit) + { + counters.addExpired(messageSize); + + globalMetrics.recordInternodeDroppedMessage(header.verb, timeElapsed, unit); + } + + @Override + public void onArrivedCorrupt(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + counters.addError(messageSize); + + messageConsumer.fail(header, new Crc.InvalidCrc(0, 0)); // could use one of the original exceptions? 
+ } + + @Override + public void onClosedBeforeArrival(int messageSize, Header header, int bytesReceived, boolean wasCorrupt, boolean wasExpired) + { + counters.addError(messageSize); + + messageConsumer.fail(header, new InvalidSerializedSizeException(header.verb, messageSize, bytesReceived)); + } + + @Override + public void onExpired(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + counters.addExpired(messageSize); + + globalMetrics.recordInternodeDroppedMessage(header.verb, timeElapsed, unit); + } + + @Override + public void onFailedDeserialize(int messageSize, Header header, Throwable t) + { + counters.addError(messageSize); + + /* + * If an exception is caught during deser, return a failure response immediately + * instead of waiting for the callback on the other end to expire. + */ + messageConsumer.fail(header, t); + } + + @Override + public void onDispatched(int messageSize, Header header) + { + counters.addPending(messageSize); + } + + @Override + public void onExecuting(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + globalMetrics.recordInternalLatency(header.verb, timeElapsed, unit); + } + + @Override + public void onExecuted(int messageSize, Header header, long timeElapsed, TimeUnit unit) + { + counters.removePending(messageSize); + } + + @Override + public void onProcessed(int messageSize, Header header) + { + counters.addProcessed(messageSize); + } + }; + } + + /* + * Aggregated counters + */ + + InboundCounters countersFor(ConnectionType type) + { + switch (type) + { + case URGENT_MESSAGES: return urgentCounters; + case SMALL_MESSAGES: return smallCounters; + case LARGE_MESSAGES: return largeCounters; + case LEGACY_MESSAGES: return legacyCounters; + } + + throw new IllegalArgumentException(); + } + + public long receivedCount() + { + return sumHandlers(h -> h.receivedCount) + closedReceivedCount; + } + + public long receivedBytes() + { + return sumHandlers(h -> h.receivedBytes) + closedReceivedBytes; + } + + 
public long throttledCount() + { + return sumHandlers(h -> h.throttledCount) + closedThrottledCount; + } + + public long throttledNanos() + { + return sumHandlers(h -> h.throttledNanos) + closedThrottledNanos; + } + + public long usingCapacity() + { + return sumHandlers(h -> h.queueSize); + } + + public long usingEndpointReserveCapacity() + { + return endpointReserveCapacity.using(); + } + + public long corruptFramesRecovered() + { + return sumHandlers(h -> h.corruptFramesRecovered) + closedCorruptFramesRecovered; + } + + public long corruptFramesUnrecovered() + { + return sumHandlers(h -> h.corruptFramesUnrecovered) + closedCorruptFramesUnrecovered; + } + + public long errorCount() + { + return sumCounters(InboundCounters::errorCount); + } + + public long errorBytes() + { + return sumCounters(InboundCounters::errorBytes); + } + + public long expiredCount() + { + return sumCounters(InboundCounters::expiredCount); + } + + public long expiredBytes() + { + return sumCounters(InboundCounters::expiredBytes); + } + + public long processedCount() + { + return sumCounters(InboundCounters::processedCount); + } + + public long processedBytes() + { + return sumCounters(InboundCounters::processedBytes); + } + + public long scheduledCount() + { + return sumCounters(InboundCounters::scheduledCount); + } + + public long scheduledBytes() + { + return sumCounters(InboundCounters::scheduledBytes); + } + + /* + * 'Archived' counter values, combined for all connections that have been closed. 
+ */ + + private volatile long closedReceivedCount, closedReceivedBytes; + + private static final AtomicLongFieldUpdater closedReceivedCountUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedReceivedCount"); + private static final AtomicLongFieldUpdater closedReceivedBytesUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedReceivedBytes"); + + private volatile long closedThrottledCount, closedThrottledNanos; + + private static final AtomicLongFieldUpdater closedThrottledCountUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedThrottledCount"); + private static final AtomicLongFieldUpdater closedThrottledNanosUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedThrottledNanos"); + + private volatile long closedCorruptFramesRecovered, closedCorruptFramesUnrecovered; + + private static final AtomicLongFieldUpdater closedCorruptFramesRecoveredUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedCorruptFramesRecovered"); + private static final AtomicLongFieldUpdater closedCorruptFramesUnrecoveredUpdater = + AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, "closedCorruptFramesUnrecovered"); + + private void absorbCounters(InboundMessageHandler handler) + { + closedReceivedCountUpdater.addAndGet(this, handler.receivedCount); + closedReceivedBytesUpdater.addAndGet(this, handler.receivedBytes); + + closedThrottledCountUpdater.addAndGet(this, handler.throttledCount); + closedThrottledNanosUpdater.addAndGet(this, handler.throttledNanos); + + closedCorruptFramesRecoveredUpdater.addAndGet(this, handler.corruptFramesRecovered); + closedCorruptFramesUnrecoveredUpdater.addAndGet(this, handler.corruptFramesUnrecovered); + } + + private long sumHandlers(ToLongFunction counter) + { + long sum = 0L; + for (InboundMessageHandler h : handlers) + sum += counter.applyAsLong(h); + return sum; + } + + private long 
sumCounters(ToLongFunction mapping) + { + return mapping.applyAsLong(urgentCounters) + + mapping.applyAsLong(smallCounters) + + mapping.applyAsLong(largeCounters) + + mapping.applyAsLong(legacyCounters); + } + + interface HandlerProvider + { + InboundMessageHandler provide(FrameDecoder decoder, + + ConnectionType type, + Channel channel, + InetAddressAndPort self, + InetAddressAndPort peer, + int version, + int largeMessageThreshold, + + int queueCapacity, + ResourceLimits.Limit endpointReserveCapacity, + ResourceLimits.Limit globalReserveCapacity, + InboundMessageHandler.WaitQueue endpointWaitQueue, + InboundMessageHandler.WaitQueue globalWaitQueue, + + InboundMessageHandler.OnHandlerClosed onClosed, + InboundMessageCallbacks callbacks, + Consumer> consumer); + } +} diff --git a/src/java/org/apache/cassandra/net/InboundSink.java b/src/java/org/apache/cassandra/net/InboundSink.java new file mode 100644 index 000000000000..df63be2d8bf8 --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundSink.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Predicate; + +import org.slf4j.LoggerFactory; + +import net.openhft.chronicle.core.util.ThrowingConsumer; +import org.apache.cassandra.db.filter.TombstoneOverwhelmingException; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.index.IndexNotAvailableException; +import org.apache.cassandra.utils.NoSpamLogger; + +/** + * A message sink that all inbound messages go through. + * + * Default sink used by {@link MessagingService} is {@link IVerbHandler#doVerb(Message)}, but it can be overridden + * to filter out certain messages, record the fact of attempted delivery, or delay arrival. + * + * This facility is most useful for test code. + * + * {@link #accept(Message)} is invoked on a thread belonging to the {@link org.apache.cassandra.concurrent.Stage} + * assigned to the {@link Verb} of the message. 
+ */ +public class InboundSink implements InboundMessageHandlers.MessageConsumer +{ + private static final NoSpamLogger noSpamLogger = + NoSpamLogger.getLogger(LoggerFactory.getLogger(InboundSink.class), 1L, TimeUnit.SECONDS); + + private static class Filtered implements ThrowingConsumer, IOException> + { + final Predicate> condition; + final ThrowingConsumer, IOException> next; + + private Filtered(Predicate> condition, ThrowingConsumer, IOException> next) + { + this.condition = condition; + this.next = next; + } + + public void accept(Message message) throws IOException + { + if (condition.test(message)) + next.accept(message); + } + } + + @SuppressWarnings("FieldMayBeFinal") + private volatile ThrowingConsumer, IOException> sink; + private static final AtomicReferenceFieldUpdater sinkUpdater + = AtomicReferenceFieldUpdater.newUpdater(InboundSink.class, ThrowingConsumer.class, "sink"); + + private final MessagingService messaging; + + InboundSink(MessagingService messaging) + { + this.messaging = messaging; + this.sink = message -> message.header.verb.handler().doVerb((Message) message); + } + + public void fail(Message.Header header, Throwable failure) + { + if (header.callBackOnFailure()) + { + Message response = Message.failureResponse(header.id, header.expiresAtNanos, RequestFailureReason.forException(failure)); + messaging.send(response, header.from); + } + } + + public void accept(Message message) + { + try + { + sink.accept(message); + } + catch (Throwable t) + { + fail(message.header, t); + + if (t instanceof TombstoneOverwhelmingException || t instanceof IndexNotAvailableException) + noSpamLogger.error(t.getMessage()); + else if (t instanceof RuntimeException) + throw (RuntimeException) t; + else + throw new RuntimeException(t); + } + } + + public void add(Predicate> allow) + { + sinkUpdater.updateAndGet(this, sink -> new Filtered(allow, sink)); + } + + public void remove(Predicate> allow) + { + sinkUpdater.updateAndGet(this, sink -> without(sink, 
allow)); + } + + public void clear() + { + sinkUpdater.updateAndGet(this, InboundSink::clear); + } + + @Deprecated // TODO: this is not the correct way to do things + public boolean allow(Message message) + { + return allows(sink, message); + } + + private static ThrowingConsumer, IOException> clear(ThrowingConsumer, IOException> sink) + { + while (sink instanceof Filtered) + sink = ((Filtered) sink).next; + return sink; + } + + private static ThrowingConsumer, IOException> without(ThrowingConsumer, IOException> sink, Predicate> condition) + { + if (!(sink instanceof Filtered)) + return sink; + + Filtered filtered = (Filtered) sink; + ThrowingConsumer, IOException> next = without(filtered.next, condition); + return condition.equals(filtered.condition) ? next + : next == filtered.next + ? sink + : new Filtered(filtered.condition, next); + } + + private static boolean allows(ThrowingConsumer, IOException> sink, Message message) + { + while (sink instanceof Filtered) + { + Filtered filtered = (Filtered) sink; + if (!filtered.condition.test(message)) + return false; + sink = filtered.next; + } + return true; + } + +} diff --git a/src/java/org/apache/cassandra/net/InboundSockets.java b/src/java/org/apache/cassandra/net/InboundSockets.java new file mode 100644 index 000000000000..8f74eaae4139 --- /dev/null +++ b/src/java/org/apache/cassandra/net/InboundSockets.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.util.concurrent.DefaultEventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.concurrent.SucceededFuture; +import org.apache.cassandra.concurrent.NamedThreadFactory; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.utils.FBUtilities; + +class InboundSockets +{ + /** + * A simple struct to wrap up the components needed for each listening socket. + */ + @VisibleForTesting + static class InboundSocket + { + public final InboundConnectionSettings settings; + + /** + * The base {@link Channel} that is doing the socket listen/accept. + * Null only until open() is invoked and {@link #binding} has yet to complete. + */ + private volatile Channel listen; + /** + * Once open() is invoked, this holds the future result of opening the socket, + * so that its completion can be waited on. Once complete, it sets itself to null. 
+ */ + private volatile ChannelFuture binding; + + // purely to prevent close racing with open + private boolean closedWithoutOpening; + + /** + * A group of the open, inbound {@link Channel}s connected to this node. This is mostly interesting so that all of + * the inbound connections/channels can be closed when the listening socket itself is being closed. + */ + private final ChannelGroup connections; + private final DefaultEventExecutor executor; + + private InboundSocket(InboundConnectionSettings settings) + { + this.settings = settings; + this.executor = new DefaultEventExecutor(new NamedThreadFactory("Listen-" + settings.bindAddress)); + this.connections = new DefaultChannelGroup(settings.bindAddress.toString(), executor); + } + + private Future open() + { + return open(pipeline -> {}); + } + + private Future open(Consumer pipelineInjector) + { + synchronized (this) + { + if (listen != null) + return new SucceededFuture<>(GlobalEventExecutor.INSTANCE, null); + if (binding != null) + return binding; + if (closedWithoutOpening) + throw new IllegalStateException(); + binding = InboundConnectionInitiator.bind(settings, connections, pipelineInjector); + } + + return binding.addListener(ignore -> { + synchronized (this) + { + if (binding.isSuccess()) + listen = binding.channel(); + binding = null; + } + }); + } + + /** + * Close this socket and any connections created on it. Once closed, this socket may not be re-opened. + * + * This may not execute synchronously, so a Future is returned encapsulating its result. 
+ * @param shutdownExecutors + */ + private Future close(Consumer shutdownExecutors) + { + AsyncPromise done = AsyncPromise.uncancellable(GlobalEventExecutor.INSTANCE); + + Runnable close = () -> { + List> closing = new ArrayList<>(); + if (listen != null) + closing.add(listen.close()); + closing.add(connections.close()); + new FutureCombiner(closing) + .addListener(future -> { + executor.shutdownGracefully(); + shutdownExecutors.accept(executor); + }) + .addListener(new PromiseNotifier<>(done)); + }; + + synchronized (this) + { + if (listen == null && binding == null) + { + closedWithoutOpening = true; + return new SucceededFuture<>(GlobalEventExecutor.INSTANCE, null); + } + + if (listen != null) + { + close.run(); + } + else + { + binding.cancel(true); + binding.addListener(future -> close.run()); + } + + return done; + } + } + + public boolean isOpen() + { + return listen != null && listen.isOpen(); + } + } + + private final List sockets; + + InboundSockets(InboundConnectionSettings template) + { + this(withDefaultBindAddresses(template)); + } + + InboundSockets(List templates) + { + this.sockets = bindings(templates); + } + + private static List withDefaultBindAddresses(InboundConnectionSettings template) + { + ImmutableList.Builder templates = ImmutableList.builder(); + templates.add(template.withBindAddress(FBUtilities.getLocalAddressAndPort())); + if (shouldListenOnBroadcastAddress()) + templates.add(template.withBindAddress(FBUtilities.getBroadcastAddressAndPort())); + return templates.build(); + } + + private static List bindings(List templates) + { + ImmutableList.Builder sockets = ImmutableList.builder(); + for (InboundConnectionSettings template : templates) + addBindings(template, sockets); + return sockets.build(); + } + + private static void addBindings(InboundConnectionSettings template, ImmutableList.Builder out) + { + InboundConnectionSettings settings = template.withDefaults(); + out.add(new InboundSocket(settings)); + if 
(settings.encryption.enable_legacy_ssl_storage_port && settings.encryption.enabled) + out.add(new InboundSocket(template.withLegacyDefaults())); + } + + public Future open(Consumer pipelineInjector) + { + List> opening = new ArrayList<>(); + for (InboundSocket socket : sockets) + opening.add(socket.open(pipelineInjector)); + + return new FutureCombiner(opening); + } + + public Future open() + { + List> opening = new ArrayList<>(); + for (InboundSocket socket : sockets) + opening.add(socket.open()); + return new FutureCombiner(opening); + } + + public boolean isListening() + { + for (InboundSocket socket : sockets) + if (socket.isOpen()) + return true; + return false; + } + + public Future close(Consumer shutdownExecutors) + { + List> closing = new ArrayList<>(); + for (InboundSocket address : sockets) + closing.add(address.close(shutdownExecutors)); + return new FutureCombiner(closing); + } + public Future close() + { + return close(e -> {}); + } + + private static boolean shouldListenOnBroadcastAddress() + { + return DatabaseDescriptor.shouldListenOnBroadcastAddress() + && !FBUtilities.getLocalAddressAndPort().equals(FBUtilities.getBroadcastAddressAndPort()); + } + + @VisibleForTesting + public List sockets() + { + return sockets; + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/InvalidSerializedSizeException.java b/src/java/org/apache/cassandra/net/InvalidSerializedSizeException.java new file mode 100644 index 000000000000..5660fd1c3884 --- /dev/null +++ b/src/java/org/apache/cassandra/net/InvalidSerializedSizeException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; + +import static java.lang.String.format; + +class InvalidSerializedSizeException extends IOException +{ + final Verb verb; + final long expectedSize; + final long actualSizeAtLeast; + + InvalidSerializedSizeException(Verb verb, long expectedSize, long actualSizeAtLeast) + { + super(format("Invalid serialized size; expected %d, actual size at least %d, for verb %s", expectedSize, actualSizeAtLeast, verb)); + this.verb = verb; + this.expectedSize = expectedSize; + this.actualSizeAtLeast = actualSizeAtLeast; + } + + InvalidSerializedSizeException(long expectedSize, long actualSizeAtLeast) + { + super(format("Invalid serialized size; expected %d, actual size at least %d", expectedSize, actualSizeAtLeast)); + this.verb = null; + this.expectedSize = expectedSize; + this.actualSizeAtLeast = actualSizeAtLeast; + } +} diff --git a/src/java/org/apache/cassandra/net/LatencyConsumer.java b/src/java/org/apache/cassandra/net/LatencyConsumer.java new file mode 100644 index 000000000000..3f10d4146a13 --- /dev/null +++ b/src/java/org/apache/cassandra/net/LatencyConsumer.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.concurrent.TimeUnit; + +public interface LatencyConsumer +{ + void accept(long timeElapsed, TimeUnit unit); +} diff --git a/src/java/org/apache/cassandra/net/LatencySubscribers.java b/src/java/org/apache/cassandra/net/LatencySubscribers.java new file mode 100644 index 000000000000..823e6d0b4917 --- /dev/null +++ b/src/java/org/apache/cassandra/net/LatencySubscribers.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import org.apache.cassandra.locator.InetAddressAndPort; + +/** + * Callback that {@link org.apache.cassandra.locator.DynamicEndpointSnitch} listens to in order + * to update host scores. + * + * FIXME: rename/specialise, since only used by DES? + */ +public class LatencySubscribers +{ + public interface Subscriber + { + void receiveTiming(InetAddressAndPort address, long latency, TimeUnit unit); + } + + private volatile Subscriber subscribers; + private static final AtomicReferenceFieldUpdater subscribersUpdater + = AtomicReferenceFieldUpdater.newUpdater(LatencySubscribers.class, Subscriber.class, "subscribers"); + + private static Subscriber merge(Subscriber a, Subscriber b) + { + if (a == null) return b; + if (b == null) return a; + return (address, latency, unit) -> { + a.receiveTiming(address, latency, unit); + b.receiveTiming(address, latency, unit); + }; + } + + public void subscribe(Subscriber subscriber) + { + subscribersUpdater.accumulateAndGet(this, subscriber, LatencySubscribers::merge); + } + + public void add(InetAddressAndPort address, long latency, TimeUnit unit) + { + Subscriber subscribers = this.subscribers; + if (subscribers != null) + subscribers.receiveTiming(address, latency, unit); + } + + /** + * Track latency information for the dynamic snitch + * + * @param cb the callback associated with this message -- this lets us know if it's a message type we're interested in + * @param address the host that replied to the message + */ + public void maybeAdd(RequestCallback cb, InetAddressAndPort address, long latency, TimeUnit unit) + { + if (cb.trackLatencyForSnitch()) + add(address, latency, unit); + } +} diff --git a/src/java/org/apache/cassandra/db/WriteResponse.java b/src/java/org/apache/cassandra/net/LegacyFlag.java similarity index 52% rename from 
src/java/org/apache/cassandra/db/WriteResponse.java rename to src/java/org/apache/cassandra/net/LegacyFlag.java index 0dddaaba88dd..b2781a1fa7b2 100644 --- a/src/java/org/apache/cassandra/db/WriteResponse.java +++ b/src/java/org/apache/cassandra/net/LegacyFlag.java @@ -15,48 +15,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.cassandra.db; +package org.apache.cassandra.net; import java.io.IOException; +import com.google.common.base.Preconditions; + import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -/* - * This empty response is sent by a replica to inform the coordinator that the write succeeded +/** + * Before 4.0 introduced flags field to {@link Message}, we used to encode flags in params field, + * using a dummy value (single byte set to 0). From now on, {@link MessageFlag} should be extended + * instead. + * + * Once 3.0/3.11 compatibility is phased out, this class should be removed. 
*/ -public final class WriteResponse +@Deprecated +final class LegacyFlag { - public static final Serializer serializer = new Serializer(); - - private static final WriteResponse instance = new WriteResponse(); + static final LegacyFlag instance = new LegacyFlag(); - private WriteResponse() + private LegacyFlag() { } - public static MessageOut createMessage() + static IVersionedSerializer serializer = new IVersionedSerializer() { - return new MessageOut<>(MessagingService.Verb.REQUEST_RESPONSE, instance, serializer); - } - - public static class Serializer implements IVersionedSerializer - { - public void serialize(WriteResponse wm, DataOutputPlus out, int version) throws IOException + public void serialize(LegacyFlag param, DataOutputPlus out, int version) throws IOException { + Preconditions.checkArgument(param == instance); + out.write(0); } - public WriteResponse deserialize(DataInputPlus in, int version) throws IOException + public LegacyFlag deserialize(DataInputPlus in, int version) throws IOException { + byte b = in.readByte(); + assert b == 0; return instance; } - public long serializedSize(WriteResponse response, int version) + public long serializedSize(LegacyFlag param, int version) { - return 0; + Preconditions.checkArgument(param == instance); + return 1; } - } + }; } diff --git a/src/java/org/apache/cassandra/net/LegacyLZ4Constants.java b/src/java/org/apache/cassandra/net/LegacyLZ4Constants.java new file mode 100644 index 000000000000..f4fca446fabe --- /dev/null +++ b/src/java/org/apache/cassandra/net/LegacyLZ4Constants.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +abstract class LegacyLZ4Constants +{ + static final int XXHASH_SEED = 0x9747B28C; + + static final int HEADER_LENGTH = 8 // magic number + + 1 // token + + 4 // compressed length + + 4 // uncompressed length + + 4; // checksum + + static final long MAGIC_NUMBER = (long) 'L' << 56 + | (long) 'Z' << 48 + | (long) '4' << 40 + | (long) 'B' << 32 + | 'l' << 24 + | 'o' << 16 + | 'c' << 8 + | 'k'; + + // offsets of header fields + static final int MAGIC_NUMBER_OFFSET = 0; + static final int TOKEN_OFFSET = 8; + static final int COMPRESSED_LENGTH_OFFSET = 9; + static final int UNCOMPRESSED_LENGTH_OFFSET = 13; + static final int CHECKSUM_OFFSET = 17; + + static final int DEFAULT_BLOCK_LENGTH = 1 << 15; // 32 KiB + static final int MAX_BLOCK_LENGTH = 1 << 25; // 32 MiB + + static final int BLOCK_TYPE_NON_COMPRESSED = 0x10; + static final int BLOCK_TYPE_COMPRESSED = 0x20; + + // xxhash to Checksum adapter discards most significant nibble of value ¯\_(ツ)_/¯ + static final int XXHASH_MASK = 0xFFFFFFF; +} diff --git a/src/java/org/apache/cassandra/net/LocalBufferPoolAllocator.java b/src/java/org/apache/cassandra/net/LocalBufferPoolAllocator.java new file mode 100644 index 000000000000..384563f7c5ff --- /dev/null +++ b/src/java/org/apache/cassandra/net/LocalBufferPoolAllocator.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; + +import io.netty.channel.EventLoop; +import org.apache.cassandra.utils.memory.BufferPool; + +/** + * Equivalent to {@link GlobalBufferPoolAllocator}, except explicitly using a specified + * {@link org.apache.cassandra.utils.memory.BufferPool.LocalPool} to allocate from. + * + * Exists to facilitate more efficient handling large messages on the inbound path, + * used by {@link ConnectionType#LARGE_MESSAGES} connections. 
+ */ +class LocalBufferPoolAllocator extends BufferPoolAllocator +{ + private final BufferPool.LocalPool pool; + private final EventLoop eventLoop; + + LocalBufferPoolAllocator(EventLoop eventLoop) + { + this.pool = new BufferPool.LocalPool().recycleWhenFree(false); + this.eventLoop = eventLoop; + } + + @Override + ByteBuffer get(int size) + { + if (!eventLoop.inEventLoop()) + throw new IllegalStateException("get() called from outside of owning event loop"); + return pool.get(size, false); + } + + @Override + ByteBuffer getAtLeast(int size) + { + if (!eventLoop.inEventLoop()) + throw new IllegalStateException("getAtLeast() called from outside of owning event loop"); + return pool.get(size, true); + } + + @Override + public void release() + { + pool.release(); + } +} diff --git a/src/java/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueue.java b/src/java/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueue.java new file mode 100644 index 000000000000..4c73bdc9cd2e --- /dev/null +++ b/src/java/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueue.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.Collection; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; + +/** + * A concurrent many-producers-to-single-consumer linked queue. + * + * Based roughly on {@link java.util.concurrent.ConcurrentLinkedQueue}, except with simpler/cheaper consumer-side + * method implementations ({@link #poll()}, {@link #remove()}, {@link #drain(Consumer)}), and padding added + * to prevent false sharing. + * + * {@link #offer(Object)} provides volatile visibility semantics. {@link #offer(Object)} is lock-free, {@link #poll()} + * and all related consumer methods are wait-free. + * + * In addition to that, provides a {@link #relaxedPeekLastAndOffer(Object)} method that we use to avoid a CAS when + * putting message handlers onto the wait queue. + */ +class ManyToOneConcurrentLinkedQueue extends ManyToOneConcurrentLinkedQueueHead implements Queue +{ + @SuppressWarnings("unused") // pad two cache lines after the head to prevent false sharing + protected long p31, p32, p33, p34, p35, p36, p37, p38, p39, p40, p41, p42, p43, p44, p45; + + ManyToOneConcurrentLinkedQueue() + { + head = tail = new Node<>(null); + } + + /** + * See {@link #relaxedIsEmpty()}. + */ + @Override + public boolean isEmpty() + { + return relaxedIsEmpty(); + } + + /** + * When invoked by the consumer thread, the answer will always be accurate. + * When invoked by a non-consumer thread, it won't always be the case: + * - {@code true} result indicates that the queue IS empty, no matter what; + * - {@code false} result indicates that the queue MIGHT BE non-empty - the value of {@code head} might + * not yet have been made externally visible by the consumer thread. 
+ */ + boolean relaxedIsEmpty() + { + return null == head.next; + } + + @Override + public int size() + { + int size = 0; + Node next = head; + while (null != (next = next.next)) + size++; + return size; + } + + @Override + public E peek() + { + Node next = head.next; + if (null == next) + return null; + return next.item; + } + + @Override + public E element() + { + E item = peek(); + if (null == item) + throw new NoSuchElementException("Queue is empty"); + return item; + } + + @Override + public E poll() + { + Node head = this.head; + Node next = head.next; + + if (null == next) + return null; + + this.lazySetHead(next); // update head reference to next before making previous head node unreachable, + head.lazySetNext(head); // to maintain the guarantee of tail being always reachable from head + + E item = next.item; + next.item = null; + return item; + } + + @Override + public E remove() + { + E item = poll(); + if (null == item) + throw new NoSuchElementException("Queue is empty"); + return item; + } + + @Override + public boolean remove(Object o) + { + if (null == o) + throw new NullPointerException(); + + Node prev = this.head; + Node next = prev.next; + + while (null != next) + { + if (o.equals(next.item)) + { + prev.lazySetNext(next.next); // update prev reference to next before making removed node unreachable, + next.lazySetNext(next); // to maintain the guarantee of tail being always reachable from head + + next.item = null; + return true; + } + + prev = next; + next = next.next; + } + + return false; + } + + /** + * Consume the queue in its entirety and feed every item to the provided {@link Consumer}. + * + * Exists primarily for convenience, and essentially just wraps {@link #poll()} in a loop. + * Yields no performance benefit over invoking {@link #poll()} manually - there just isn't + * anything to meaningfully amortise on the consumer side of this queue. 
+ */ + void drain(Consumer consumer) + { + E item; + while ((item = poll()) != null) + consumer.accept(item); + } + + @Override + public boolean add(E e) + { + return offer(e); + } + + @Override + public boolean offer(E e) + { + internalOffer(e); return true; + } + + /** + * Adds the element to the queue and returns the item of the previous tail node. + * It's possible for the returned item to already have been consumed. + * + * @return previously last tail item in the queue, potentially stale + */ + E relaxedPeekLastAndOffer(E e) + { + return internalOffer(e); + } + + /** + * internalOffer() is based on {@link java.util.concurrent.ConcurrentLinkedQueue#offer(Object)}, + * written by Doug Lea and Martin Buchholz with assistance from members of JCP JSR-166 Expert Group + * and released to the public domain, as explained at http://creativecommons.org/publicdomain/zero/1.0/ + */ + private E internalOffer(E e) + { + if (null == e) + throw new NullPointerException(); + + final Node node = new Node<>(e); + + for (Node t = tail, p = t;;) + { + Node q = p.next; + if (q == null) + { + // p is last node + if (p.casNext(null, node)) + { + // successful CAS is the linearization point for e to become an element of this queue and for node to become "live". + if (p != t) // hop two nodes at a time + casTail(t, node); // failure is ok + return p.item; + } + // lost CAS race to another thread; re-read next + } + else if (p == q) + { + /* + * We have fallen off list. If tail is unchanged, it will also be off-list, in which case we need to + * jump to head, from which all live nodes are always reachable. Else the new tail is a better bet. + */ + p = (t != (t = tail)) ? t : head; + } + else + { + // check for tail updates after two hops + p = (p != t && t != (t = tail)) ? 
t : q; + } + } + } + + @Override + public boolean contains(Object o) + { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() + { + throw new UnsupportedOperationException(); + } + + @Override + public Object[] toArray() + { + throw new UnsupportedOperationException(); + } + + @Override + public T[] toArray(T[] a) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean containsAll(Collection c) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection c) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection c) + { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() + { + throw new UnsupportedOperationException(); + } +} + +class ManyToOneConcurrentLinkedQueueHead extends ManyToOneConcurrentLinkedQueuePadding2 +{ + protected volatile ManyToOneConcurrentLinkedQueue.Node head; + + private static final AtomicReferenceFieldUpdater headUpdater = + AtomicReferenceFieldUpdater.newUpdater(ManyToOneConcurrentLinkedQueueHead.class, Node.class, "head"); + + @SuppressWarnings("WeakerAccess") + protected void lazySetHead(Node val) + { + headUpdater.lazySet(this, val); + } +} + +class ManyToOneConcurrentLinkedQueuePadding2 extends ManyToOneConcurrentLinkedQueueTail +{ + @SuppressWarnings("unused") // pad two cache lines between tail and head to prevent false sharing + protected long p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26, p27, p28, p29, p30; +} + +class ManyToOneConcurrentLinkedQueueTail extends ManyToOneConcurrentLinkedQueuePadding1 +{ + protected volatile ManyToOneConcurrentLinkedQueue.Node tail; + + private static final AtomicReferenceFieldUpdater tailUpdater = + AtomicReferenceFieldUpdater.newUpdater(ManyToOneConcurrentLinkedQueueTail.class, Node.class, "tail"); + 
+ @SuppressWarnings({ "WeakerAccess", "UnusedReturnValue" }) + protected boolean casTail(Node expect, Node update) + { + return tailUpdater.compareAndSet(this, expect, update); + } +} + +class ManyToOneConcurrentLinkedQueuePadding1 +{ + @SuppressWarnings("unused") // pad two cache lines before the tail to prevent false sharing + protected long p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15; + + static final class Node + { + E item; + volatile Node next; + + private static final AtomicReferenceFieldUpdater nextUpdater = + AtomicReferenceFieldUpdater.newUpdater(Node.class, Node.class, "next"); + + Node(E item) + { + this.item = item; + } + + @SuppressWarnings("SameParameterValue") + boolean casNext(Node expect, Node update) + { + return nextUpdater.compareAndSet(this, expect, update); + } + + void lazySetNext(Node val) + { + nextUpdater.lazySet(this, val); + } + } +} diff --git a/src/java/org/apache/cassandra/net/Message.java b/src/java/org/apache/cassandra/net/Message.java new file mode 100644 index 000000000000..05c1bfaa7b1a --- /dev/null +++ b/src/java/org/apache/cassandra/net/Message.java @@ -0,0 +1,1338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.EnumMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.Ints; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.tracing.Tracing.TraceType; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.MonotonicClockTranslation; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.db.TypeSizes.sizeof; +import static org.apache.cassandra.db.TypeSizes.sizeofUnsignedVInt; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; +import static org.apache.cassandra.net.MessagingService.VERSION_3014; +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.MessagingService.instance; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; +import static org.apache.cassandra.utils.vint.VIntCoding.computeUnsignedVIntSize; +import static org.apache.cassandra.utils.vint.VIntCoding.getUnsignedVInt; +import static org.apache.cassandra.utils.vint.VIntCoding.skipUnsignedVInt; + +/** 
+ * Immutable main unit of internode communication - what used to be {@code MessageIn} and {@code MessageOut} fused + * in one class. + * + * @param The type of the message payload. + */ +public class Message +{ + public final Header header; + public final T payload; + + private Message(Header header, T payload) + { + this.header = header; + this.payload = payload; + } + + /** Sender of the message. */ + public InetAddressAndPort from() + { + return header.from; + } + + /** Whether the message has crossed the node boundary, that is whether it originated from another node. */ + public boolean isCrossNode() + { + return !from().equals(FBUtilities.getBroadcastAddressAndPort()); + } + + /** + * id of the request/message. In 4.0+ can be shared between multiple messages of the same logical request, + * whilst in versions above a new id would be allocated for each message sent. + */ + public long id() + { + return header.id; + } + + public Verb verb() + { + return header.verb; + } + + boolean isFailureResponse() + { + return verb() == Verb.FAILURE_RSP; + } + + /** + * Creation time of the message. If cross-node timeouts are enabled ({@link DatabaseDescriptor#hasCrossNodeTimeout()}, + * {@code deserialize()} will use the marshalled value, otherwise will use current time on the deserializing machine. + */ + public long createdAtNanos() + { + return header.createdAtNanos; + } + + public long expiresAtNanos() + { + return header.expiresAtNanos; + } + + /** For how long the message has lived. 
*/ + public long elapsedSinceCreated(TimeUnit units) + { + return units.convert(approxTime.now() - createdAtNanos(), NANOSECONDS); + } + + public long creationTimeMillis() + { + return approxTime.translate().toMillisSinceEpoch(createdAtNanos()); + } + + /** Whether a failure response should be returned upon failure */ + boolean callBackOnFailure() + { + return header.callBackOnFailure(); + } + + /** See CASSANDRA-14145 */ + public boolean trackRepairedData() + { + return header.trackRepairedData(); + } + + /** Used for cross-DC write optimisation - pick one node in the DC and have it relay the write to its local peers */ + @Nullable + public ForwardingInfo forwardTo() + { + return header.forwardTo(); + } + + /** The originator of the request - used when forwarding and will differ from {@link #from()} */ + @Nullable + public InetAddressAndPort respondTo() + { + return header.respondTo(); + } + + @Nullable + public UUID traceSession() + { + return header.traceSession(); + } + + @Nullable + public TraceType traceType() + { + return header.traceType(); + } + + /* + * request/response convenience + */ + + /** + * Make a request {@link Message} with supplied verb and payload. Will fill in remaining fields + * automatically. + * + * If you know that you will need to set some params or flags - prefer using variants of {@code out()} + * that allow providing them at point of message constructions, rather than allocating new messages + * with those added flags and params. See {@code outWithFlag()}, {@code outWithFlags()}, and {@code outWithParam()} + * family. 
+ */ + public static Message out(Verb verb, T payload) + { + assert !verb.isResponse(); + + return outWithParam(nextId(), verb, payload, null, null); + } + + public static Message outWithFlag(Verb verb, T payload, MessageFlag flag) + { + assert !verb.isResponse(); + return outWithParam(nextId(), verb, 0, payload, flag.addTo(0), null, null); + } + + public static Message outWithFlags(Verb verb, T payload, MessageFlag flag1, MessageFlag flag2) + { + assert !verb.isResponse(); + return outWithParam(nextId(), verb, 0, payload, flag2.addTo(flag1.addTo(0)), null, null); + } + + static Message outWithParam(long id, Verb verb, T payload, ParamType paramType, Object paramValue) + { + return outWithParam(id, verb, 0, payload, paramType, paramValue); + } + + private static Message outWithParam(long id, Verb verb, long expiresAtNanos, T payload, ParamType paramType, Object paramValue) + { + return outWithParam(id, verb, expiresAtNanos, payload, 0, paramType, paramValue); + } + + private static Message outWithParam(long id, Verb verb, long expiresAtNanos, T payload, int flags, ParamType paramType, Object paramValue) + { + if (payload == null) + throw new IllegalArgumentException(); + + InetAddressAndPort from = FBUtilities.getBroadcastAddressAndPort(); + long createdAtNanos = approxTime.now(); + if (expiresAtNanos == 0) + expiresAtNanos = verb.expiresAtNanos(createdAtNanos); + + return new Message<>(new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, buildParams(paramType, paramValue)), payload); + } + + public static Message internalResponse(Verb verb, T payload) + { + assert verb.isResponse(); + return outWithParam(0, verb, payload, null, null); + } + + /** Builds a response Message with provided payload, and all the right fields inferred from request Message */ + public Message responseWith(T payload) + { + return outWithParam(id(), verb().responseVerb, expiresAtNanos(), payload, null, null); + } + + /** Builds a response Message with no payload, and all the 
right fields inferred from request Message */ + public Message emptyResponse() + { + return responseWith(NoPayload.noPayload); + } + + /** Builds a failure response Message with an explicit reason, and fields inferred from request Message */ + public Message failureResponse(RequestFailureReason reason) + { + return failureResponse(id(), expiresAtNanos(), reason); + } + + static Message failureResponse(long id, long expiresAtNanos, RequestFailureReason reason) + { + return outWithParam(id, Verb.FAILURE_RSP, expiresAtNanos, reason, null, null); + } + + Message withCallBackOnFailure() + { + return new Message<>(header.withFlag(MessageFlag.CALL_BACK_ON_FAILURE), payload); + } + + public Message withForwardTo(ForwardingInfo peers) + { + return new Message<>(header.withParam(ParamType.FORWARD_TO, peers), payload); + } + + private static final EnumMap NO_PARAMS = new EnumMap<>(ParamType.class); + + private static Map buildParams(ParamType type, Object value) + { + Map params = NO_PARAMS; + if (Tracing.isTracing()) + params = Tracing.instance.addTraceHeaders(new EnumMap<>(ParamType.class)); + + if (type != null) + { + if (params.isEmpty()) + params = new EnumMap<>(ParamType.class); + params.put(type, value); + } + + return params; + } + + private static Map addParam(Map params, ParamType type, Object value) + { + if (type == null) + return params; + + params = new EnumMap<>(params); + params.put(type, value); + return params; + } + + /* + * id generation + */ + + private static final long NO_ID = 0L; // this is a valid ID for pre40 nodes + + private static final AtomicInteger nextId = new AtomicInteger(0); + + private static long nextId() + { + long id; + do + { + id = nextId.incrementAndGet(); + } + while (id == NO_ID); + + return id; + } + + /** + * WARNING: this is inaccurate for messages from pre40 nodes, which can use 0 as an id (but will do so rarely) + */ + @VisibleForTesting + boolean hasId() + { + return id() != NO_ID; + } + + /** we preface every message with 
this number so the recipient can validate the sender is sane */ + static final int PROTOCOL_MAGIC = 0xCA552DFA; + + static void validateLegacyProtocolMagic(int magic) throws InvalidLegacyProtocolMagic + { + if (magic != PROTOCOL_MAGIC) + throw new InvalidLegacyProtocolMagic(magic); + } + + public static final class InvalidLegacyProtocolMagic extends IOException + { + public final int read; + private InvalidLegacyProtocolMagic(int read) + { + super(String.format("Read %d, Expected %d", read, PROTOCOL_MAGIC)); + this.read = read; + } + } + + public String toString() + { + return "(from:" + from() + ", type:" + verb().stage + " verb:" + verb() + ')'; + } + + /** + * Split into a separate object to allow partial message deserialization without wasting work and allocation + * afterwards, if the entire message is necessary and available. + */ + public static class Header + { + public final long id; + public final Verb verb; + public final InetAddressAndPort from; + public final long createdAtNanos; + public final long expiresAtNanos; + private final int flags; + private final Map params; + + private Header(long id, Verb verb, InetAddressAndPort from, long createdAtNanos, long expiresAtNanos, int flags, Map params) + { + this.id = id; + this.verb = verb; + this.from = from; + this.createdAtNanos = createdAtNanos; + this.expiresAtNanos = expiresAtNanos; + this.flags = flags; + this.params = params; + } + + Header withFlag(MessageFlag flag) + { + return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flag.addTo(flags), params); + } + + Header withParam(ParamType type, Object value) + { + return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, addParam(params, type, value)); + } + + boolean callBackOnFailure() + { + return MessageFlag.CALL_BACK_ON_FAILURE.isIn(flags); + } + + boolean trackRepairedData() + { + return MessageFlag.TRACK_REPAIRED_DATA.isIn(flags); + } + + @Nullable + ForwardingInfo forwardTo() + { + return (ForwardingInfo) 
params.get(ParamType.FORWARD_TO); + } + + @Nullable + InetAddressAndPort respondTo() + { + return (InetAddressAndPort) params.get(ParamType.RESPOND_TO); + } + + @Nullable + public UUID traceSession() + { + return (UUID) params.get(ParamType.TRACE_SESSION); + } + + @Nullable + public TraceType traceType() + { + return (TraceType) params.getOrDefault(ParamType.TRACE_TYPE, TraceType.QUERY); + } + } + + @SuppressWarnings("WeakerAccess") + public static class Builder + { + private Verb verb; + private InetAddressAndPort from; + private T payload; + private int flags = 0; + private final Map params = new EnumMap<>(ParamType.class); + private long createdAtNanos; + private long expiresAtNanos; + private long id; + + private boolean hasId; + + private Builder() + { + } + + public Builder from(InetAddressAndPort from) + { + this.from = from; + return this; + } + + public Builder withPayload(T payload) + { + this.payload = payload; + return this; + } + + public Builder withFlag(MessageFlag flag) + { + flags = flag.addTo(flags); + return this; + } + + public Builder withFlags(int flags) + { + this.flags = flags; + return this; + } + + public Builder withParam(ParamType type, Object value) + { + params.put(type, value); + return this; + } + + public Builder withoutParam(ParamType type) + { + params.remove(type); + return this; + } + + public Builder withParams(Map params) + { + this.params.putAll(params); + return this; + } + + public Builder ofVerb(Verb verb) + { + this.verb = verb; + if (expiresAtNanos == 0 && verb != null && createdAtNanos != 0) + expiresAtNanos = verb.expiresAtNanos(createdAtNanos); + if (!this.verb.isResponse() && from == null) // default to sending from self if we're a request verb + from = FBUtilities.getBroadcastAddressAndPort(); + return this; + } + + public Builder withCreatedAt(long createdAtNanos) + { + this.createdAtNanos = createdAtNanos; + if (expiresAtNanos == 0 && verb != null) + expiresAtNanos = verb.expiresAtNanos(createdAtNanos); + return 
this; + } + + public Builder withExpiresAt(long expiresAtNanos) + { + this.expiresAtNanos = expiresAtNanos; + return this; + } + + public Builder withId(long id) + { + this.id = id; + hasId = true; + return this; + } + + public Message build() + { + if (verb == null) + throw new IllegalArgumentException(); + if (from == null) + throw new IllegalArgumentException(); + if (payload == null) + throw new IllegalArgumentException(); + + return new Message<>(new Header(hasId ? id : nextId(), verb, from, createdAtNanos, expiresAtNanos, flags, params), payload); + } + } + + public static Builder builder(Message message) + { + return new Builder().from(message.from()) + .withId(message.id()) + .ofVerb(message.verb()) + .withCreatedAt(message.createdAtNanos()) + .withExpiresAt(message.expiresAtNanos()) + .withFlags(message.header.flags) + .withParams(message.header.params) + .withPayload(message.payload); + } + + public static Builder builder(Verb verb, T payload) + { + return new Builder().ofVerb(verb) + .withCreatedAt(approxTime.now()) + .withPayload(payload); + } + + public static final Serializer serializer = new Serializer(); + + /** + * Each message contains a header with several fixed fields, an optional key-value params section, and then + * the message payload itself. Below is a visualization of the layout. + * + * The params are prefixed by the count of key-value pairs; this value is encoded as unsigned vint. + * An individual param has an unsvint id (more specifically, a {@link ParamType}), and a byte array value. + * The param value is prefixed with it's length, encoded as an unsigned vint, followed by by the value's bytes. + * + * Legacy Notes (see {@link Serializer#serialize(Message, DataOutputPlus, int)} for complete details): + * - pre 4.0, the IP address was sent along in the header, before the verb. 
The IP address may be either IPv4 (4 bytes) or IPv6 (16 bytes) + * - pre-4.0, the verb was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint + * - pre-4.0, the payloadSize was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint + * - pre-4.0, the count of param key-value pairs was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint + * - pre-4.0, param names were encoded as strings; in 4.0 they are encoded as enum id vints + * - pre-4.0, expiry time wasn't encoded at all; in 4.0 it's an unsigned vint + * - pre-4.0, message id was an int; in 4.0 and up it's an unsigned vint + * - pre-4.0, messages included PROTOCOL MAGIC BYTES; post-4.0, we rely on frame CRCs instead + * - pre-4.0, messages would serialize boolean params as dummy ONE_BYTEs; post-4.0 we have a dedicated 'flags' vint + * + *
+     * {@code
+     *            1 1 1 1 1 2 2 2 2 2 3
+     *  0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Message ID (vint)             |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Creation timestamp (int)      |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Expiry (vint)                 |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Verb (vint)                   |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Flags (vint)                  |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Param count (vint)            |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * |                               /
+     * /           Params              /
+     * /                               |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Payload size (vint)           |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * |                               /
+     * /           Payload             /
+     * /                               |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * }
+     * 
+ */ + public static final class Serializer + { + private static final int CREATION_TIME_SIZE = 4; + + private Serializer() + { + } + + public void serialize(Message message, DataOutputPlus out, int version) throws IOException + { + if (version >= VERSION_40) + serializePost40(message, out, version); + else + serializePre40(message, out, version); + } + + public Message deserialize(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException + { + return version >= VERSION_40 ? deserializePost40(in, peer, version) : deserializePre40(in, version); + } + + /** + * A partial variant of deserialize, taking in a previously deserialized {@link Header} as an argument. + * + * Skip deserializing the {@link Header} from the input stream in favour of using the provided header. + */ + public Message deserialize(DataInputPlus in, Header header, int version) throws IOException + { + return version >= VERSION_40 ? deserializePost40(in, header, version) : deserializePre40(in, header, version); + } + + private int serializedSize(Message message, int version) + { + return version >= VERSION_40 ? serializedSizePost40(message, version) : serializedSizePre40(message, version); + } + + /** + * Size of the next message in the stream. Returns -1 if there aren't sufficient bytes read yet to determine size. + */ + int inferMessageSize(ByteBuffer buf, int index, int limit, int version) throws InvalidLegacyProtocolMagic + { + int size = version >= VERSION_40 ? inferMessageSizePost40(buf, index, limit) : inferMessageSizePre40(buf, index, limit); + if (size > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes()) + throw new OversizedMessageException(size); + return size; + } + + /** + * Partially deserialize the message - by only extracting the header and leaving the payload alone. + * + * To get the rest of the message without repeating the work done here, use {@link #deserialize(DataInputPlus, Header, int)} + * method. 
+ * + * It's assumed that the provided buffer contains all the bytes necessary to deserialize the header fully. + */ + Header extractHeader(ByteBuffer buf, InetAddressAndPort from, long currentTimeNanos, int version) throws IOException + { + return version >= VERSION_40 + ? extractHeaderPost40(buf, from, currentTimeNanos, version) + : extractHeaderPre40(buf, currentTimeNanos, version); + } + + private static long getExpiresAtNanos(long createdAtNanos, long currentTimeNanos, long expirationPeriodNanos) + { + if (!DatabaseDescriptor.hasCrossNodeTimeout() || createdAtNanos > currentTimeNanos) + createdAtNanos = currentTimeNanos; + return createdAtNanos + expirationPeriodNanos; + } + + /* + * 4.0 ser/deser + */ + + private void serializeHeaderPost40(Header header, DataOutputPlus out, int version) throws IOException + { + out.writeUnsignedVInt(header.id); + // int cast cuts off the high-order half of the timestamp, which we can assume remains + // the same between now and when the recipient reconstructs it. 
+ out.writeInt((int) approxTime.translate().toMillisSinceEpoch(header.createdAtNanos)); + out.writeUnsignedVInt(1 + NANOSECONDS.toMillis(header.expiresAtNanos - header.createdAtNanos)); + out.writeUnsignedVInt(header.verb.id); + out.writeUnsignedVInt(header.flags); + serializeParams(header.params, out, version); + } + + private Header deserializeHeaderPost40(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException + { + long id = in.readUnsignedVInt(); + long currentTimeNanos = approxTime.now(); + MonotonicClockTranslation timeSnapshot = approxTime.translate(); + long creationTimeNanos = calculateCreationTimeNanos(in.readInt(), timeSnapshot, currentTimeNanos); + long expiresAtNanos = getExpiresAtNanos(creationTimeNanos, currentTimeNanos, TimeUnit.MILLISECONDS.toNanos(in.readUnsignedVInt())); + Verb verb = Verb.fromId(Ints.checkedCast(in.readUnsignedVInt())); + int flags = Ints.checkedCast(in.readUnsignedVInt()); + Map params = deserializeParams(in, version); + return new Header(id, verb, peer, creationTimeNanos, expiresAtNanos, flags, params); + } + + private void skipHeaderPost40(DataInputPlus in) throws IOException + { + skipUnsignedVInt(in); // id + in.skipBytesFully(4); // createdAt + skipUnsignedVInt(in); // expiresIn + skipUnsignedVInt(in); // verb + skipUnsignedVInt(in); // flags + skipParamsPost40(in); // params + } + + private int serializedHeaderSizePost40(Header header, int version) + { + long size = 0; + size += sizeofUnsignedVInt(header.id); + size += CREATION_TIME_SIZE; + size += sizeofUnsignedVInt(1 + NANOSECONDS.toMillis(header.expiresAtNanos - header.createdAtNanos)); + size += sizeofUnsignedVInt(header.verb.id); + size += sizeofUnsignedVInt(header.flags); + size += serializedParamsSize(header.params, version); + return Ints.checkedCast(size); + } + + private Header extractHeaderPost40(ByteBuffer buf, InetAddressAndPort from, long currentTimeNanos, int version) throws IOException + { + MonotonicClockTranslation timeSnapshot = 
approxTime.translate(); + + int index = buf.position(); + + long id = getUnsignedVInt(buf, index); + index += computeUnsignedVIntSize(id); + + int createdAtMillis = buf.getInt(index); + index += sizeof(createdAtMillis); + + long expiresInMillis = getUnsignedVInt(buf, index); + index += computeUnsignedVIntSize(expiresInMillis); + + Verb verb = Verb.fromId(Ints.checkedCast(getUnsignedVInt(buf, index))); + index += computeUnsignedVIntSize(verb.id); + + int flags = Ints.checkedCast(getUnsignedVInt(buf, index)); + index += computeUnsignedVIntSize(flags); + + Map params = extractParams(buf, index, version); + + long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos); + long expiresAtNanos = getExpiresAtNanos(createdAtNanos, currentTimeNanos, TimeUnit.MILLISECONDS.toNanos(expiresInMillis)); + + return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params); + } + + private void serializePost40(Message message, DataOutputPlus out, int version) throws IOException + { + serializeHeaderPost40(message.header, out, version); + out.writeUnsignedVInt(message.payloadSize(version)); + message.verb().serializer().serialize(message.payload, out, version); + } + + private Message deserializePost40(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException + { + Header header = deserializeHeaderPost40(in, peer, version); + skipUnsignedVInt(in); // payload size, not needed by payload deserializer + T payload = (T) header.verb.serializer().deserialize(in, version); + return new Message<>(header, payload); + } + + private Message deserializePost40(DataInputPlus in, Header header, int version) throws IOException + { + skipHeaderPost40(in); + skipUnsignedVInt(in); // payload size, not needed by payload deserializer + T payload = (T) header.verb.serializer().deserialize(in, version); + return new Message<>(header, payload); + } + + private int serializedSizePost40(Message message, int version) + { + long size = 0; + 
size += serializedHeaderSizePost40(message.header, version); + int payloadSize = message.payloadSize(version); + size += sizeofUnsignedVInt(payloadSize) + payloadSize; + return Ints.checkedCast(size); + } + + private int inferMessageSizePost40(ByteBuffer buf, int readerIndex, int readerLimit) + { + int index = readerIndex; + + int idSize = computeUnsignedVIntSize(buf, index, readerLimit); + if (idSize < 0) + return -1; // not enough bytes to read id + index += idSize; + + index += CREATION_TIME_SIZE; + if (index > readerLimit) + return -1; + + int expirationSize = computeUnsignedVIntSize(buf, index, readerLimit); + if (expirationSize < 0) + return -1; + index += expirationSize; + + int verbIdSize = computeUnsignedVIntSize(buf, index, readerLimit); + if (verbIdSize < 0) + return -1; + index += verbIdSize; + + int flagsSize = computeUnsignedVIntSize(buf, index, readerLimit); + if (flagsSize < 0) + return -1; + index += flagsSize; + + int paramsSize = extractParamsSizePost40(buf, index, readerLimit); + if (paramsSize < 0) + return -1; + index += paramsSize; + + long payloadSize = getUnsignedVInt(buf, index, readerLimit); + if (payloadSize < 0) + return -1; + index += computeUnsignedVIntSize(payloadSize) + payloadSize; + + return index - readerIndex; + } + + /* + * legacy ser/deser + */ + + private void serializeHeaderPre40(Header header, DataOutputPlus out, int version) throws IOException + { + out.writeInt(PROTOCOL_MAGIC); + out.writeInt(Ints.checkedCast(header.id)); + // int cast cuts off the high-order half of the timestamp, which we can assume remains + // the same between now and when the recipient reconstructs it. 
+ out.writeInt((int) approxTime.translate().toMillisSinceEpoch(header.createdAtNanos)); + inetAddressAndPortSerializer.serialize(header.from, out, version); + out.writeInt(header.verb.toPre40Verb().id); + serializeParams(addFlagsToLegacyParams(header.params, header.flags), out, version); + } + + private Header deserializeHeaderPre40(DataInputPlus in, int version) throws IOException + { + validateLegacyProtocolMagic(in.readInt()); + int id = in.readInt(); + long currentTimeNanos = approxTime.now(); + MonotonicClockTranslation timeSnapshot = approxTime.translate(); + long creationTimeNanos = calculateCreationTimeNanos(in.readInt(), timeSnapshot, currentTimeNanos); + InetAddressAndPort from = inetAddressAndPortSerializer.deserialize(in, version); + Verb verb = Verb.fromId(in.readInt()); + Map params = deserializeParams(in, version); + int flags = removeFlagsFromLegacyParams(params); + return new Header(id, verb, from, creationTimeNanos, verb.expiresAtNanos(creationTimeNanos), flags, params); + } + + private static final int PRE_40_MESSAGE_PREFIX_SIZE = 12; // protocol magic + id + createdAt + + private void skipHeaderPre40(DataInputPlus in) throws IOException + { + in.skipBytesFully(PRE_40_MESSAGE_PREFIX_SIZE); // magic, id, createdAt + in.skipBytesFully(in.readByte()); // from + in.skipBytesFully(4); // verb + skipParamsPre40(in); // params + } + + private int serializedHeaderSizePre40(Header header, int version) + { + long size = 0; + size += PRE_40_MESSAGE_PREFIX_SIZE; + size += inetAddressAndPortSerializer.serializedSize(header.from, version); + size += sizeof(header.verb.id); + size += serializedParamsSize(addFlagsToLegacyParams(header.params, header.flags), version); + return Ints.checkedCast(size); + } + + private Header extractHeaderPre40(ByteBuffer buf, long currentTimeNanos, int version) throws IOException + { + MonotonicClockTranslation timeSnapshot = approxTime.translate(); + + int index = buf.position(); + + index += 4; // protocol magic + + long id = 
buf.getInt(index); + index += 4; + + int createdAtMillis = buf.getInt(index); + index += 4; + + InetAddressAndPort from = inetAddressAndPortSerializer.extract(buf, index); + index += 1 + buf.get(index); + + Verb verb = Verb.fromId(buf.getInt(index)); + index += 4; + + Map params = extractParams(buf, index, version); + int flags = removeFlagsFromLegacyParams(params); + + long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos); + long expiresAtNanos = verb.expiresAtNanos(createdAtNanos); + + return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params); + } + + private void serializePre40(Message message, DataOutputPlus out, int version) throws IOException + { + if (message.isFailureResponse()) + message = toPre40FailureResponse(message); + + serializeHeaderPre40(message.header, out, version); + + if (message.payload != null && message.payload != NoPayload.noPayload) + { + int payloadSize = message.payloadSize(version); + out.writeInt(payloadSize); + message.verb().serializer().serialize(message.payload, out, version); + } + else + { + out.writeInt(0); + } + } + + private Message deserializePre40(DataInputPlus in, int version) throws IOException + { + Header header = deserializeHeaderPre40(in, version); + return deserializePre40(in, header, false, version); + } + + private Message deserializePre40(DataInputPlus in, Header header, int version) throws IOException + { + return deserializePre40(in, header, true, version); + } + + private Message deserializePre40(DataInputPlus in, Header header, boolean skipHeader, int version) throws IOException + { + if (skipHeader) + skipHeaderPre40(in); + + IVersionedAsymmetricSerializer payloadSerializer = header.verb.serializer(); + if (null == payloadSerializer) + payloadSerializer = instance().callbacks.responseSerializer(header.id, header.from); + int payloadSize = in.readInt(); + T payload = deserializePayloadPre40(in, version, payloadSerializer, payloadSize); + + 
Message message = new Message<>(header, payload); + + return header.params.containsKey(ParamType.FAILURE_RESPONSE) + ? (Message) toPost40FailureResponse(message) + : message; + } + + private T deserializePayloadPre40(DataInputPlus in, int version, IVersionedAsymmetricSerializer serializer, int payloadSize) throws IOException + { + if (payloadSize == 0 || serializer == null) + { + // if there's no deserializer for the verb, skip the payload bytes to leave + // the stream in a clean state (for the next message) + in.skipBytesFully(payloadSize); + return null; + } + + return serializer.deserialize(in, version); + } + + private int serializedSizePre40(Message message, int version) + { + if (message.isFailureResponse()) + message = toPre40FailureResponse(message); + + long size = 0; + size += serializedHeaderSizePre40(message.header, version); + int payloadSize = message.payloadSize(version); + size += sizeof(payloadSize); + size += payloadSize; + return Ints.checkedCast(size); + } + + private int inferMessageSizePre40(ByteBuffer buf, int readerIndex, int readerLimit) throws InvalidLegacyProtocolMagic + { + int index = readerIndex; + // protocol magic + index += 4; + if (index > readerLimit) + return -1; + validateLegacyProtocolMagic(buf.getInt(index - 4)); + + // rest of prefix + index += PRE_40_MESSAGE_PREFIX_SIZE - 4; + // ip address + index += 1; + if (index > readerLimit) + return -1; + index += buf.get(index - 1); + // verb + index += 4; + if (index > readerLimit) + return -1; + + int paramsSize = extractParamsSizePre40(buf, index, readerLimit); + if (paramsSize < 0) + return -1; + index += paramsSize; + + // payload + index += 4; + + if (index > readerLimit) + return -1; + index += buf.getInt(index - 4); + + return index - readerIndex; + } + + private Message toPre40FailureResponse(Message post40) + { + Map params = new EnumMap<>(ParamType.class); + params.putAll(post40.header.params); + + params.put(ParamType.FAILURE_RESPONSE, LegacyFlag.instance); + 
params.put(ParamType.FAILURE_REASON, post40.payload); + + Header header = new Header(post40.id(), post40.verb().toPre40Verb(), post40.from(), post40.createdAtNanos(), post40.expiresAtNanos(), 0, params); + return new Message<>(header, NoPayload.noPayload); + } + + private Message toPost40FailureResponse(Message pre40) + { + Map params = new EnumMap<>(ParamType.class); + params.putAll(pre40.header.params); + + params.remove(ParamType.FAILURE_RESPONSE); + + RequestFailureReason reason = (RequestFailureReason) params.remove(ParamType.FAILURE_REASON); + if (null == reason) + reason = RequestFailureReason.UNKNOWN; + + Header header = new Header(pre40.id(), Verb.FAILURE_RSP, pre40.from(), pre40.createdAtNanos(), pre40.expiresAtNanos(), pre40.header.flags, params); + return new Message<>(header, reason); + } + + /* + * created at + cross-node + */ + + private static final long TIMESTAMP_WRAPAROUND_GRACE_PERIOD_START = 0xFFFFFFFFL - MINUTES.toMillis(15L); + private static final long TIMESTAMP_WRAPAROUND_GRACE_PERIOD_END = MINUTES.toMillis(15L); + + private static long calculateCreationTimeNanos(int messageTimestampMillis, MonotonicClockTranslation timeSnapshot, long currentTimeNanos) + { + long currentTimeMillis = timeSnapshot.toMillisSinceEpoch(currentTimeNanos); + // Reconstruct the message construction time sent by the remote host (we sent only the lower 4 bytes, assuming the + // higher 4 bytes wouldn't change between the sender and receiver) + long highBits = currentTimeMillis & 0xFFFFFFFF00000000L; + + long sentLowBits = messageTimestampMillis & 0x00000000FFFFFFFFL; + long currentLowBits = currentTimeMillis & 0x00000000FFFFFFFFL; + + // if our sent bits occur within a grace period of a wrap around event, + // and our current bits are no more than the same grace period after a wrap around event, + // assume a wrap around has occurred, and deduct one highBit + if ( sentLowBits > TIMESTAMP_WRAPAROUND_GRACE_PERIOD_START + && currentLowBits < 
TIMESTAMP_WRAPAROUND_GRACE_PERIOD_END) + { + highBits -= 0x0000000100000000L; + } + + long sentTimeMillis = (highBits | sentLowBits); + return timeSnapshot.fromMillisSinceEpoch(sentTimeMillis); + } + + /* + * param ser/deser + */ + + private Map addFlagsToLegacyParams(Map params, int flags) + { + if (flags == 0) + return params; + + Map extended = new EnumMap<>(ParamType.class); + extended.putAll(params); + + if (MessageFlag.CALL_BACK_ON_FAILURE.isIn(flags)) + extended.put(ParamType.FAILURE_CALLBACK, LegacyFlag.instance); + + if (MessageFlag.TRACK_REPAIRED_DATA.isIn(flags)) + extended.put(ParamType.TRACK_REPAIRED_DATA, LegacyFlag.instance); + + return extended; + } + + private int removeFlagsFromLegacyParams(Map params) + { + int flags = 0; + + if (null != params.remove(ParamType.FAILURE_CALLBACK)) + flags = MessageFlag.CALL_BACK_ON_FAILURE.addTo(flags); + + if (null != params.remove(ParamType.TRACK_REPAIRED_DATA)) + flags = MessageFlag.TRACK_REPAIRED_DATA.addTo(flags); + + return flags; + } + + private void serializeParams(Map params, DataOutputPlus out, int version) throws IOException + { + if (version >= VERSION_40) + out.writeUnsignedVInt(params.size()); + else + out.writeInt(params.size()); + + for (Map.Entry kv : params.entrySet()) + { + ParamType type = kv.getKey(); + if (version >= VERSION_40) + out.writeUnsignedVInt(type.id); + else + out.writeUTF(type.legacyAlias); + + IVersionedSerializer serializer = type.serializer; + Object value = kv.getValue(); + + int length = Ints.checkedCast(serializer.serializedSize(value, version)); + if (version >= VERSION_40) + out.writeUnsignedVInt(length); + else + out.writeInt(length); + + serializer.serialize(value, out, version); + } + } + + private Map deserializeParams(DataInputPlus in, int version) throws IOException + { + int count = version >= VERSION_40 ? 
Ints.checkedCast(in.readUnsignedVInt()) : in.readInt(); + + if (count == 0) + return NO_PARAMS; + + Map params = new EnumMap<>(ParamType.class); + + for (int i = 0; i < count; i++) + { + ParamType type = version >= VERSION_40 + ? ParamType.lookUpById(Ints.checkedCast(in.readUnsignedVInt())) + : ParamType.lookUpByAlias(in.readUTF()); + + int length = version >= VERSION_40 + ? Ints.checkedCast(in.readUnsignedVInt()) + : in.readInt(); + + if (null != type) + params.put(type, type.serializer.deserialize(in, version)); + else + in.skipBytesFully(length); // forward compatibiliy with minor version changes + } + + return params; + } + + /* + * Extract post-4.0 params map from a ByteBuffer without modifying it. + */ + private Map extractParams(ByteBuffer buf, int readerIndex, int version) throws IOException + { + long count = version >= VERSION_40 ? getUnsignedVInt(buf, readerIndex) : buf.getInt(readerIndex); + + if (count == 0) + return NO_PARAMS; + + final int position = buf.position(); + buf.position(readerIndex); + + try (DataInputBuffer in = new DataInputBuffer(buf, false)) + { + return deserializeParams(in, version); + } + finally + { + buf.position(position); + } + } + + private void skipParamsPost40(DataInputPlus in) throws IOException + { + int count = Ints.checkedCast(in.readUnsignedVInt()); + + for (int i = 0; i < count; i++) + { + skipUnsignedVInt(in); + in.skipBytesFully(Ints.checkedCast(in.readUnsignedVInt())); + } + } + + private void skipParamsPre40(DataInputPlus in) throws IOException + { + int count = in.readInt(); + + for (int i = 0; i < count; i++) + { + in.skipBytesFully(in.readShort()); + in.skipBytesFully(in.readInt()); + } + } + + private long serializedParamsSize(Map params, int version) + { + long size = version >= VERSION_40 + ? 
computeUnsignedVIntSize(params.size()) + : sizeof(params.size()); + + for (Map.Entry kv : params.entrySet()) + { + ParamType type = kv.getKey(); + Object value = kv.getValue(); + + long valueLength = type.serializer.serializedSize(value, version); + + if (version >= VERSION_40) + size += sizeofUnsignedVInt(type.id) + sizeofUnsignedVInt(valueLength); + else + size += sizeof(type.legacyAlias) + 4; + + size += valueLength; + } + + return size; + } + + private int extractParamsSizePost40(ByteBuffer buf, int readerIndex, int readerLimit) + { + int index = readerIndex; + + long paramsCount = getUnsignedVInt(buf, index, readerLimit); + if (paramsCount < 0) + return -1; + index += computeUnsignedVIntSize(paramsCount); + + for (int i = 0; i < paramsCount; i++) + { + long type = getUnsignedVInt(buf, index, readerLimit); + if (type < 0) + return -1; + index += computeUnsignedVIntSize(type); + + long length = getUnsignedVInt(buf, index, readerLimit); + if (length < 0) + return -1; + index += computeUnsignedVIntSize(length) + length; + } + + return index - readerIndex; + } + + private int extractParamsSizePre40(ByteBuffer buf, int readerIndex, int readerLimit) + { + int index = readerIndex; + + index += 4; + if (index > readerLimit) + return -1; + int paramsCount = buf.getInt(index - 4); + + for (int i = 0; i < paramsCount; i++) + { + // try to read length and skip to the end of the param name + index += 2; + + if (index > readerLimit) + return -1; + index += buf.getShort(index - 2); + // try to read length and skip to the end of the param value + index += 4; + if (index > readerLimit) + return -1; + index += buf.getInt(index - 4); + } + + return index - readerIndex; + } + + private int payloadSize(Message message, int version) + { + long payloadSize = message.payload != null && message.payload != NoPayload.noPayload + ? 
message.verb().serializer().serializedSize(message.payload, version) + : 0; + return Ints.checkedCast(payloadSize); + } + } + + private int serializedSize30; + private int serializedSize3014; + private int serializedSize40; + + /** + * Serialized size of the entire message, for the provided messaging version. Caches the calculated value. + */ + public int serializedSize(int version) + { + switch (version) + { + case VERSION_30: + if (serializedSize30 == 0) + serializedSize30 = serializer.serializedSize(this, VERSION_30); + return serializedSize30; + case VERSION_3014: + if (serializedSize3014 == 0) + serializedSize3014 = serializer.serializedSize(this, VERSION_3014); + return serializedSize3014; + case VERSION_40: + if (serializedSize40 == 0) + serializedSize40 = serializer.serializedSize(this, VERSION_40); + return serializedSize40; + default: + throw new IllegalStateException(); + } + } + + private int payloadSize30 = -1; + private int payloadSize3014 = -1; + private int payloadSize40 = -1; + + private int payloadSize(int version) + { + switch (version) + { + case VERSION_30: + if (payloadSize30 < 0) + payloadSize30 = serializer.payloadSize(this, VERSION_30); + return payloadSize30; + case VERSION_3014: + if (payloadSize3014 < 0) + payloadSize3014 = serializer.payloadSize(this, VERSION_3014); + return payloadSize3014; + case VERSION_40: + if (payloadSize40 < 0) + payloadSize40 = serializer.payloadSize(this, VERSION_40); + return payloadSize40; + default: + throw new IllegalStateException(); + } + } + + static class OversizedMessageException extends RuntimeException + { + OversizedMessageException(int size) + { + super("Message of size " + size + " bytes exceeds allowed maximum of " + DatabaseDescriptor.getInternodeMaxMessageSizeInBytes() + " bytes"); + } + } +} diff --git a/src/java/org/apache/cassandra/net/MessageDeliveryTask.java b/src/java/org/apache/cassandra/net/MessageDeliveryTask.java deleted file mode 100644 index 1b9090c3081b..000000000000 --- 
a/src/java/org/apache/cassandra/net/MessageDeliveryTask.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.net; - -import java.io.IOException; -import java.util.EnumSet; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.primitives.Shorts; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.db.filter.TombstoneOverwhelmingException; -import org.apache.cassandra.db.monitoring.ApproximateTime; -import org.apache.cassandra.exceptions.RequestFailureReason; -import org.apache.cassandra.gms.Gossiper; -import org.apache.cassandra.index.IndexNotAvailableException; -import org.apache.cassandra.io.DummyByteVersionedSerializer; -import org.apache.cassandra.io.util.DataOutputBuffer; - -public class MessageDeliveryTask implements Runnable -{ - private static final Logger logger = LoggerFactory.getLogger(MessageDeliveryTask.class); - - private final MessageIn message; - private final int id; - private final long enqueueTime; - - public MessageDeliveryTask(MessageIn message, int id) - { - assert message != null; - this.message = message; - this.id = id; - this.enqueueTime = 
ApproximateTime.currentTimeMillis(); - } - - public void run() - { - process(); - } - - /** - * A helper function for making unit testing reasonable. - * - * @return true if the message was processed; else false. - */ - @VisibleForTesting - boolean process() - { - MessagingService.Verb verb = message.verb; - if (verb == null) - { - logger.trace("Unknown verb {}", verb); - return false; - } - - MessagingService.instance().metrics.addQueueWaitTime(verb.toString(), - ApproximateTime.currentTimeMillis() - enqueueTime); - - long timeTaken = message.getLifetimeInMS(); - if (MessagingService.DROPPABLE_VERBS.contains(verb) - && timeTaken > message.getTimeout()) - { - MessagingService.instance().incrementDroppedMessages(message, timeTaken); - return false; - } - - IVerbHandler verbHandler = MessagingService.instance().getVerbHandler(verb); - if (verbHandler == null) - { - logger.trace("No handler for verb {}", verb); - return false; - } - - try - { - verbHandler.doVerb(message, id); - } - catch (IOException ioe) - { - handleFailure(ioe); - throw new RuntimeException(ioe); - } - catch (TombstoneOverwhelmingException | IndexNotAvailableException e) - { - handleFailure(e); - logger.error(e.getMessage()); - } - catch (Throwable t) - { - handleFailure(t); - throw t; - } - - if (GOSSIP_VERBS.contains(message.verb)) - Gossiper.instance.setLastProcessedMessageAt(message.constructionTime); - return true; - } - - private void handleFailure(Throwable t) - { - if (message.doCallbackOnFailure()) - { - MessageOut response = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) - .withParameter(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); - - if (t instanceof TombstoneOverwhelmingException) - { - response = response.withParameter(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); - } - - MessagingService.instance().sendReply(response, id, message.from); - } - } - - private static final EnumSet GOSSIP_VERBS = 
EnumSet.of(MessagingService.Verb.GOSSIP_DIGEST_ACK, - MessagingService.Verb.GOSSIP_DIGEST_ACK2, - MessagingService.Verb.GOSSIP_DIGEST_SYN); -} diff --git a/src/java/org/apache/cassandra/net/MessageFlag.java b/src/java/org/apache/cassandra/net/MessageFlag.java new file mode 100644 index 000000000000..c74784d4caba --- /dev/null +++ b/src/java/org/apache/cassandra/net/MessageFlag.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import static java.lang.Math.max; + +/** + * Binary message flags to be passed as {@code flags} field of {@link Message}. 
+ */ +public enum MessageFlag +{ + /** a failure response should be sent back in case of failure */ + CALL_BACK_ON_FAILURE (0), + /** track repaired data - see CASSANDRA-14145 */ + TRACK_REPAIRED_DATA (1); + + private final int id; + + MessageFlag(int id) + { + this.id = id; + } + + /** + * @return {@code true} if the flag is present in provided flags, {@code false} otherwise + */ + boolean isIn(int flags) + { + return (flags & (1 << id)) != 0; + } + + /** + * @return new flags value with this flag added + */ + int addTo(int flags) + { + return flags | (1 << id); + } + + private static final MessageFlag[] idToFlagMap; + static + { + MessageFlag[] flags = values(); + + int max = -1; + for (MessageFlag flag : flags) + max = max(flag.id, max); + + MessageFlag[] idMap = new MessageFlag[max + 1]; + for (MessageFlag flag : flags) + { + if (idMap[flag.id] != null) + throw new RuntimeException("Two MessageFlag-s that map to the same id: " + flag.id); + idMap[flag.id] = flag; + } + idToFlagMap = idMap; + } + + @SuppressWarnings("unused") + MessageFlag lookUpById(int id) + { + if (id < 0) + throw new IllegalArgumentException("MessageFlag id must be non-negative (got " + id + ')'); + + return id < idToFlagMap.length ? idToFlagMap[id] : null; + } +} + diff --git a/src/java/org/apache/cassandra/net/MessageIn.java b/src/java/org/apache/cassandra/net/MessageIn.java deleted file mode 100644 index c8f4bfcd2608..000000000000 --- a/src/java/org/apache/cassandra/net/MessageIn.java +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.net; - -import java.io.IOException; -import java.util.Collections; -import java.util.Map; - -import com.google.common.collect.ImmutableMap; - -import org.apache.cassandra.concurrent.Stage; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.db.monitoring.ApproximateTime; -import org.apache.cassandra.exceptions.RequestFailureReason; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService.Verb; -import org.apache.cassandra.utils.FBUtilities; - -/** - * The receiving node's view of a {@link MessageOut}. See documentation on {@link MessageOut} for details on the - * serialization format. 
- * - * @param The type of the payload - */ -public class MessageIn -{ - public final InetAddressAndPort from; - public final T payload; - public final Map parameters; - public final MessagingService.Verb verb; - public final int version; - public final long constructionTime; - - public MessageIn(InetAddressAndPort from, - T payload, - Map parameters, - Verb verb, - int version, - long constructionTime) - { - this.from = from; - this.payload = payload; - this.parameters = parameters; - this.verb = verb; - this.version = version; - this.constructionTime = constructionTime; - } - - public static MessageIn create(InetAddressAndPort from, - T payload, - Map parameters, - Verb verb, - int version, - long constructionTime) - { - return new MessageIn<>(from, payload, parameters, verb, version, constructionTime); - } - - public static MessageIn create(InetAddressAndPort from, - T payload, - Map parameters, - MessagingService.Verb verb, - int version) - { - return new MessageIn<>(from, payload, parameters, verb, version, ApproximateTime.currentTimeMillis()); - } - - public static MessageIn read(DataInputPlus in, int version, int id) throws IOException - { - return read(in, version, id, ApproximateTime.currentTimeMillis()); - } - - public static MessageIn read(DataInputPlus in, int version, int id, long constructionTime) throws IOException - { - InetAddressAndPort from = CompactEndpointSerializationHelper.instance.deserialize(in, version); - - MessagingService.Verb verb = MessagingService.Verb.fromId(in.readInt()); - Map parameters = readParameters(in, version); - int payloadSize = in.readInt(); - return read(in, version, id, constructionTime, from, payloadSize, verb, parameters); - } - - public static Map readParameters(DataInputPlus in, int version) throws IOException - { - int parameterCount = in.readInt(); - Map parameters; - if (parameterCount == 0) - { - return Collections.emptyMap(); - } - else - { - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i 
= 0; i < parameterCount; i++) - { - String key = in.readUTF(); - ParameterType type = ParameterType.byName.get(key); - if (type != null) - { - byte[] value = new byte[in.readInt()]; - in.readFully(value); - try (DataInputBuffer buffer = new DataInputBuffer(value)) - { - builder.put(type, type.serializer.deserialize(buffer, version)); - } - } - else - { - in.skipBytes(in.readInt()); - } - } - return builder.build(); - } - } - - public static MessageIn read(DataInputPlus in, int version, int id, long constructionTime, - InetAddressAndPort from, int payloadSize, Verb verb, Map parameters) throws IOException - { - IVersionedSerializer serializer = (IVersionedSerializer) MessagingService.verbSerializers.get(verb); - if (serializer instanceof MessagingService.CallbackDeterminedSerializer) - { - CallbackInfo callback = MessagingService.instance().getRegisteredCallback(id); - if (callback == null) - { - // reply for expired callback. we'll have to skip it. - in.skipBytesFully(payloadSize); - return null; - } - serializer = (IVersionedSerializer) callback.serializer; - } - - if (payloadSize == 0 || serializer == null) - { - // if there's no deserializer for the verb, skip the payload bytes to leave - // the stream in a clean state (for the next message) - in.skipBytesFully(payloadSize); - return create(from, null, parameters, verb, version, constructionTime); - } - - T2 payload = serializer.deserialize(in, version); - return MessageIn.create(from, payload, parameters, verb, version, constructionTime); - } - - public static long deriveConstructionTime(InetAddressAndPort from, int messageTimestamp, long currentTime) - { - // Reconstruct the message construction time sent by the remote host (we sent only the lower 4 bytes, assuming the - // higher 4 bytes wouldn't change between the sender and receiver) - long sentConstructionTime = (currentTime & 0xFFFFFFFF00000000L) | (((messageTimestamp & 0xFFFFFFFFL) << 2) >> 2); - - // Because nodes may not have their clock perfectly in 
sync, it's actually possible the sentConstructionTime is - // later than the currentTime (the received time). If that's the case, as we definitively know there is a lack - // of proper synchronziation of the clock, we ignore sentConstructionTime. We also ignore that - // sentConstructionTime if we're told to. - long elapsed = currentTime - sentConstructionTime; - if (elapsed > 0) - MessagingService.instance().metrics.addTimeTaken(from, elapsed); - - boolean useSentTime = DatabaseDescriptor.hasCrossNodeTimeout() && elapsed > 0; - return useSentTime ? sentConstructionTime : currentTime; - } - - /** - * Since how long (in milliseconds) the message has lived. - */ - public long getLifetimeInMS() - { - return ApproximateTime.currentTimeMillis() - constructionTime; - } - - /** - * Whether the message has crossed the node boundary, that is whether it originated from another node. - * - */ - public boolean isCrossNode() - { - return !from.equals(FBUtilities.getBroadcastAddressAndPort()); - } - - public Stage getMessageType() - { - return MessagingService.verbStages.get(verb); - } - - public boolean doCallbackOnFailure() - { - return parameters.containsKey(ParameterType.FAILURE_CALLBACK); - } - - public boolean isFailureResponse() - { - return parameters.containsKey(ParameterType.FAILURE_RESPONSE); - } - - public RequestFailureReason getFailureReason() - { - Short code = (Short)parameters.get(ParameterType.FAILURE_REASON); - return code != null ? 
RequestFailureReason.fromCode(code) : RequestFailureReason.UNKNOWN; - } - - public long getTimeout() - { - return verb.getTimeout(); - } - - public long getSlowQueryTimeout() - { - return DatabaseDescriptor.getSlowQueryTimeout(); - } - - public String toString() - { - StringBuilder sbuf = new StringBuilder(); - sbuf.append("FROM:").append(from).append(" TYPE:").append(getMessageType()).append(" VERB:").append(verb); - return sbuf.toString(); - } -} diff --git a/src/java/org/apache/cassandra/net/MessageOut.java b/src/java/org/apache/cassandra/net/MessageOut.java deleted file mode 100644 index 834435e3bd73..000000000000 --- a/src/java/org/apache/cassandra/net/MessageOut.java +++ /dev/null @@ -1,406 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net; - -import java.io.IOError; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Ints; - -import org.apache.cassandra.concurrent.Stage; -import org.apache.cassandra.db.TypeSizes; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataOutputBuffer; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; -import org.apache.cassandra.tracing.Tracing; -import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.vint.VIntCoding; - -import static org.apache.cassandra.tracing.Tracing.isTracing; - -/** - * Each message contains a header with several fixed fields, an optional key-value parameters section, and then - * the message payload itself. Note: the legacy IP address (pre-4.0) in the header may be either IPv4 (4 bytes) - * or IPv6 (16 bytes). The diagram below shows the IPv4 address for brevity. In pre-4.0, the payloadSize was - * encoded as a 4-byte integer; in 4.0 and up it is an unsigned byte (255 parameters should be enough for anyone). - * - *
- * {@code
- *            1 1 1 1 1 2 2 2 2 2 3
- *  0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |       PROTOCOL MAGIC          |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |        Message ID             |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |        Timestamp              |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |          Verb                 |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |ParmLen| Parameter data (var)  |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |   Payload size (vint)         |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * |                               /
- * /           Payload             /
- * /                               |
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- * }
- * 
- * - * An individual parameter has a String key and a byte array value. The key is serialized with it's length, - * encoded as two bytes, followed by the UTF-8 byte encoding of the string (see {@link java.io.DataOutput#writeUTF(java.lang.String)}). - * The body is serialized with it's length, encoded as four bytes, followed by the bytes of the value. - * - * * @param The type of the message payload. - */ -public class MessageOut -{ - private static final int SERIALIZED_SIZE_VERSION_UNDEFINED = -1; - //Parameters are stored in an object array as tuples of size two - public static final int PARAMETER_TUPLE_SIZE = 2; - //Offset in a parameter tuple containing the type of the parameter - public static final int PARAMETER_TUPLE_TYPE_OFFSET = 0; - //Offset in a parameter tuple containing the actual parameter represented as a POJO - public static final int PARAMETER_TUPLE_PARAMETER_OFFSET = 1; - - public final InetAddressAndPort from; - public final MessagingService.Verb verb; - public final T payload; - public final IVersionedSerializer serializer; - //A list of tuples, first object is the ParameterType enum value, - //the second object is the POJO to serialize - public final List parameters; - - /** - * Allows sender to explicitly state which connection type the message should be sent on. - */ - public final ConnectionType connectionType; - - /** - * Memoization of the serialized size of the just the payload. - */ - private int payloadSerializedSize = -1; - - /** - * Memoization of the serialized size of the entire message. - */ - private int serializedSize = -1; - - /** - * The internode protocol messaging version that was used to calculate the memoized serailized sizes. 
- */ - private int serializedSizeVersion = SERIALIZED_SIZE_VERSION_UNDEFINED; - - // we do support messages that just consist of a verb - public MessageOut(MessagingService.Verb verb) - { - this(verb, null, null); - } - - public MessageOut(MessagingService.Verb verb, T payload, IVersionedSerializer serializer) - { - this(verb, - payload, - serializer, - isTracing() ? Tracing.instance.getTraceHeaders() : ImmutableList.of(), - null); - } - - public MessageOut(MessagingService.Verb verb, T payload, IVersionedSerializer serializer, ConnectionType connectionType) - { - this(verb, - payload, - serializer, - isTracing() ? Tracing.instance.getTraceHeaders() : ImmutableList.of(), - connectionType); - } - - private MessageOut(MessagingService.Verb verb, T payload, IVersionedSerializer serializer, List parameters, ConnectionType connectionType) - { - this(FBUtilities.getBroadcastAddressAndPort(), verb, payload, serializer, parameters, connectionType); - } - - @VisibleForTesting - public MessageOut(InetAddressAndPort from, MessagingService.Verb verb, T payload, IVersionedSerializer serializer, List parameters, ConnectionType connectionType) - { - this.from = from; - this.verb = verb; - this.payload = payload; - this.serializer = serializer; - this.parameters = parameters; - this.connectionType = connectionType; - } - - public MessageOut withParameter(ParameterType type, VT value) - { - List newParameters = new ArrayList<>(parameters.size() + 2); - newParameters.addAll(parameters); - newParameters.add(type); - newParameters.add(value); - return new MessageOut(from, verb, payload, serializer, newParameters, connectionType); - } - - public Stage getStage() - { - return MessagingService.verbStages.get(verb); - } - - public long getTimeout() - { - return verb.getTimeout(); - } - - public String toString() - { - StringBuilder sbuf = new StringBuilder(); - sbuf.append("TYPE:").append(getStage()).append(" VERB:").append(verb); - return sbuf.toString(); - } - - public void 
serialize(DataOutputPlus out, int version) throws IOException - { - if (version >= MessagingService.VERSION_40) - serialize40(out, version); - else - serializePre40(out, version); - } - - private void serialize40(DataOutputPlus out, int version) throws IOException - { - out.writeInt(verb.getId()); - - // serialize the headers, if any - assert parameters.size() % PARAMETER_TUPLE_SIZE == 0; - if (parameters.isEmpty()) - { - out.writeVInt(0); - } - else - { - try (DataOutputBuffer buf = new DataOutputBuffer()) - { - serializeParams(buf, version); - out.writeUnsignedVInt(buf.getLength()); - out.write(buf.buffer()); - } - } - - if (payload != null) - { - int payloadSize = payloadSerializedSize >= 0 - ? payloadSerializedSize - : (int) serializer.serializedSize(payload, version); - - out.writeUnsignedVInt(payloadSize); - serializer.serialize(payload, out, version); - } - else - { - out.writeUnsignedVInt(0); - } - } - - private void serializePre40(DataOutputPlus out, int version) throws IOException - { - CompactEndpointSerializationHelper.instance.serialize(from, out, version); - out.writeInt(verb.getId()); - - assert parameters.size() % PARAMETER_TUPLE_SIZE == 0; - out.writeInt(parameters.size() / PARAMETER_TUPLE_SIZE); - serializeParams(out, version); - - if (payload != null) - { - int payloadSize = payloadSerializedSize >= 0 - ? 
payloadSerializedSize - : (int) serializer.serializedSize(payload, version); - - out.writeInt(payloadSize); - serializer.serialize(payload, out, version); - } - else - { - out.writeInt(0); - } - } - - private void serializeParams(DataOutputPlus out, int version) throws IOException - { - for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) - { - ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); - out.writeUTF(type.key); - IVersionedSerializer serializer = type.serializer; - Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); - - int valueLength = Ints.checkedCast(serializer.serializedSize(parameter, version)); - if (version >= MessagingService.VERSION_40) - out.writeUnsignedVInt(valueLength); - else - out.writeInt(valueLength); - - serializer.serialize(parameter, out, version); - } - } - - private MessageOutSizes calculateSerializedSize(int version) - { - return version >= MessagingService.VERSION_40 - ? calculateSerializedSize40(version) - : calculateSerializedSizePre40(version); - } - - private MessageOutSizes calculateSerializedSize40(int version) - { - long size = 0; - size += TypeSizes.sizeof(verb.getId()); - - if (parameters.isEmpty()) - { - size += VIntCoding.computeVIntSize(0); - } - else - { - // calculate the params size independently, as we write that before the actual params block - int paramsSize = 0; - for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) - { - ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); - paramsSize += TypeSizes.sizeof(type.key()); - IVersionedSerializer serializer = type.serializer; - Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); - int valueLength = Ints.checkedCast(serializer.serializedSize(parameter, version)); - paramsSize += VIntCoding.computeUnsignedVIntSize(valueLength);//length prefix - paramsSize += valueLength; - } - size += 
VIntCoding.computeUnsignedVIntSize(paramsSize); - size += paramsSize; - } - - long payloadSize = payload == null ? 0 : serializer.serializedSize(payload, version); - assert payloadSize <= Integer.MAX_VALUE; // larger values are supported in sstables but not messages - size += VIntCoding.computeUnsignedVIntSize(payloadSize); - size += payloadSize; - return new MessageOutSizes(size, payloadSize); - } - - private MessageOutSizes calculateSerializedSizePre40(int version) - { - long size = 0; - size += CompactEndpointSerializationHelper.instance.serializedSize(from, version); - - size += TypeSizes.sizeof(verb.getId()); - size += TypeSizes.sizeof(parameters.size() / PARAMETER_TUPLE_SIZE); - for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) - { - ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); - size += TypeSizes.sizeof(type.key()); - size += 4;//length prefix - IVersionedSerializer serializer = type.serializer; - Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); - size += serializer.serializedSize(parameter, version); - } - - long payloadSize = payload == null ? 0 : serializer.serializedSize(payload, version); - assert payloadSize <= Integer.MAX_VALUE; // larger values are supported in sstables but not messages - size += TypeSizes.sizeof((int) payloadSize); - size += payloadSize; - return new MessageOutSizes(size, payloadSize); - } - - /** - * Calculate the size of this message for the specified protocol version and memoize the result for the specified - * protocol version. Memoization only covers the protocol version of the first invocation. - * - * It is not safe to call this function concurrently from multiple threads unless it has already been invoked - * once from a single thread and there is a happens before relationship between that invocation and other - * threads concurrently invoking this function. 
- * - * For instance it would be safe to invokePayload size to make a decision in the thread that created the message - * and then hand it off to other threads via a thread-safe queue, volatile write, or synchronized/ReentrantLock. - * - * @param version Protocol version to use when calculating size - * @return Size of this message in bytes, which will be less than or equal to {@link Integer#MAX_VALUE} - */ - public int serializedSize(int version) - { - if (serializedSize > 0 && serializedSizeVersion == version) - return serializedSize; - - MessageOutSizes sizes = calculateSerializedSize(version); - if (sizes.messageSize > Integer.MAX_VALUE) - throw new IllegalStateException("message size exceeds maximum allowed size: size = " + sizes.messageSize); - - if (serializedSizeVersion == SERIALIZED_SIZE_VERSION_UNDEFINED) - { - serializedSize = Ints.checkedCast(sizes.messageSize); - payloadSerializedSize = Ints.checkedCast(sizes.payloadSize); - serializedSizeVersion = version; - } - - return Ints.checkedCast(sizes.messageSize); - } - - public Object getParameter(ParameterType type) - { - for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) - { - if (((ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET)).equals(type)) - { - return parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); - } - } - return null; - } - - private static class MessageOutSizes - { - public final long messageSize; - public final long payloadSize; - - private MessageOutSizes(long messageSize, long payloadSize) - { - this.messageSize = messageSize; - this.payloadSize = payloadSize; - } - - @Override - public final int hashCode() - { - int hashCode = (int) messageSize ^ (int) (messageSize >>> 32); - return 31 * (hashCode ^ (int) ((int) payloadSize ^ (payloadSize >>> 32))); - } - - @Override - public final boolean equals(Object o) - { - if (!(o instanceof MessageOutSizes)) - return false; - MessageOutSizes that = (MessageOutSizes) o; - return messageSize == that.messageSize && 
payloadSize == that.payloadSize; - } - } -} diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java index f72cd619d855..8b5ab1acdda4 100644 --- a/src/java/org/apache/cassandra/net/MessagingService.java +++ b/src/java/org/apache/cassandra/net/MessagingService.java @@ -1,4 +1,4 @@ - /* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,109 +17,192 @@ */ package org.apache.cassandra.net; -import java.io.IOError; -import java.io.IOException; -import java.net.UnknownHostException; +import java.nio.channels.ClosedChannelException; import java.util.ArrayList; import java.util.Collection; -import java.util.EnumMap; -import java.util.EnumSet; -import java.util.HashMap; +import java.util.Collections; import java.util.HashSet; import java.util.List; -import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.TimeoutException; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; -import com.google.common.collect.Lists; -import org.cliffc.high_scale_lib.NonBlockingHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.carrotsearch.hppc.IntObjectMap; -import com.carrotsearch.hppc.IntObjectOpenHashMap; -import io.netty.channel.Channel; -import io.netty.channel.group.ChannelGroup; -import io.netty.channel.group.DefaultChannelGroup; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.batchlog.Batch; -import org.apache.cassandra.concurrent.ExecutorLocals; -import 
org.apache.cassandra.concurrent.LocalAwareExecutorService; +import io.netty.util.concurrent.Future; import org.apache.cassandra.concurrent.ScheduledExecutors; import org.apache.cassandra.concurrent.Stage; import org.apache.cassandra.concurrent.StageManager; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.db.ColumnFamilyStore; -import org.apache.cassandra.db.ConsistencyLevel; -import org.apache.cassandra.db.CounterMutation; -import org.apache.cassandra.db.IMutation; -import org.apache.cassandra.db.Keyspace; -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.db.ReadCommand; -import org.apache.cassandra.db.ReadResponse; -import org.apache.cassandra.db.SnapshotCommand; import org.apache.cassandra.db.SystemKeyspace; -import org.apache.cassandra.db.TruncateResponse; -import org.apache.cassandra.db.Truncation; -import org.apache.cassandra.db.WriteResponse; -import org.apache.cassandra.dht.AbstractBounds; -import org.apache.cassandra.dht.BootStrapper; -import org.apache.cassandra.dht.IPartitioner; -import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.exceptions.RequestFailureReason; -import org.apache.cassandra.gms.EchoMessage; -import org.apache.cassandra.gms.GossipDigestAck; -import org.apache.cassandra.gms.GossipDigestAck2; -import org.apache.cassandra.gms.GossipDigestSyn; -import org.apache.cassandra.hints.HintMessage; -import org.apache.cassandra.hints.HintResponse; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.IEndpointSnitch; -import org.apache.cassandra.locator.ILatencySubscriber; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; -import 
org.apache.cassandra.metrics.CassandraMetricsRegistry; -import org.apache.cassandra.metrics.ConnectionMetrics; -import org.apache.cassandra.metrics.DroppedMessageMetrics; -import org.apache.cassandra.metrics.MessagingMetrics; -import org.apache.cassandra.net.async.OutboundMessagingPool; -import org.apache.cassandra.net.async.NettyFactory; -import org.apache.cassandra.net.async.NettyFactory.InboundInitializer; -import org.apache.cassandra.repair.messages.RepairMessage; -import org.apache.cassandra.schema.MigrationManager; -import org.apache.cassandra.schema.TableId; -import org.apache.cassandra.security.SSLFactory; import org.apache.cassandra.service.AbstractWriteResponseHandler; -import org.apache.cassandra.service.StorageProxy; -import org.apache.cassandra.service.StorageService; -import org.apache.cassandra.service.paxos.Commit; -import org.apache.cassandra.service.paxos.PrepareResponse; -import org.apache.cassandra.tracing.TraceState; -import org.apache.cassandra.tracing.Tracing; -import org.apache.cassandra.utils.BooleanSerializer; -import org.apache.cassandra.utils.ExpiringMap; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.MBeanWrapper; -import org.apache.cassandra.utils.NativeLibrary; -import org.apache.cassandra.utils.Pair; -import org.apache.cassandra.utils.StatusLogger; -import org.apache.cassandra.utils.UUIDSerializer; -import org.apache.cassandra.utils.concurrent.SimpleCondition; -public final class MessagingService implements MessagingServiceMBean +import static java.util.Collections.synchronizedList; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.concurrent.Stage.MUTATION; +import static org.apache.cassandra.utils.Throwables.maybeFail; + +/** + * MessagingService implements all internode communication - with the exception of SSTable streaming (for now). 
+ * + * Specifically, it's responsible for dispatch of outbound messages to other nodes and routing of inbound messages + * to their appropriate {@link IVerbHandler}. + * + *

Using MessagingService: sending requests and responses

+ * + * There are two ways to send a {@link Message}, and you should pick one depending on the desired behaviour: + * 1. To send a request that expects a response back, use + * {@link #sendWithCallback(Message, InetAddressAndPort, RequestCallback)} method. Once a response + * message is received, {@link RequestCallback#onResponse(Message)} method will be invoked on the + * provided callback - in case of a success response. In case of a failure response (see {@link Verb#FAILURE_RSP}), + * or if a response doesn't arrive within verb's configured expiry time, + * {@link RequestCallback#onFailure(InetAddressAndPort, RequestFailureReason)} will be invoked instead. + * 2. To send a response back, or a message that expects no response, use {@link #send(Message, InetAddressAndPort)} + * method. + * + * See also: {@link Message#out(Verb, Object)}, {@link Message#responseWith(Object)}, + * and {@link Message#failureResponse(RequestFailureReason)}. + * + *

Using MessagingService: handling a request

+ * + * As described in the previous section, to handle responses you only need to implement {@link RequestCallback} + * interface - so long as your response verb handler is the default {@link ResponseVerbHandler}. + * + * There are two steps you need to perform to implement request handling: + * 1. Create a {@link IVerbHandler} to process incoming requests and responses for the new type (if applicable). + * 2. Add a new {@link Verb} to the enum for the new request type, and, if applicable, one for the response message. + * + * MessagingService will now automatically invoke your handler whenever a {@link Message} with this verb arrives. + * + *

Architecture of MessagingService

+ * + *

QOS

+ * + * Since our messaging protocol is TCP-based, and also doesn't yet support interleaving messages with each other, + * we need a way to prevent head-of-line blocking adversely affecting all messages - in particular, large messages + * being in the way of smaller ones. To achieve that (somewhat), we maintain three messaging connections to and + * from each peer: + * - one for large messages - defined as being larger than {@link OutboundConnections#LARGE_MESSAGE_THRESHOLD} + * (65KiB by default) + * - one for small messages - defined as smaller than that threshold + * - and finally, a connection for urgent messages - usually small and/or that are important to arrive + * promptly, e.g. gossip-related ones + * + *

Wire format and framing

+ * + * Small messages are grouped together into frames, and large messages are split over multiple frames. + * Framing provides application-level integrity protection to otherwise raw streams of data - we use + * CRC24 for frame headers and CRC32 for the entire payload. LZ4 is optionally used for compression. + * + * You can find the on-wire format description of individual messages in the comments for + * {@link Message.Serializer}, alongside with format evolution notes. + * For the list and descriptions of available frame decoders see {@link FrameDecoder} comments. You can + * find wire format documented in the javadoc of {@link FrameDecoder} implementations: + * see {@link FrameDecoderCrc} and {@link FrameDecoderLZ4} in particular. + * + *

Architecture of outbound messaging

+ * + * {@link OutboundConnection} is the core class implementing outbound connection logic, with + * {@link OutboundConnection#enqueue(Message)} being its main entry point. The connections are initiated + * by {@link OutboundConnectionInitiator}. + * + * Netty pipeline for outbound messaging connections generally consists of the following handlers: + * + * [(optional) SslHandler] <- [FrameEncoder] + * + * {@link OutboundConnection} handles the entire lifetime of a connection: from the very first handshake + * to any necessary reconnects if necessary. + * + * Message-delivery flow varies depending on the connection type. + * + * For {@link ConnectionType#SMALL_MESSAGES} and {@link ConnectionType#URGENT_MESSAGES}, + * {@link Message} serialization and delivery occurs directly on the event loop. + * See {@link OutboundConnection.EventLoopDelivery} for details. + * + * For {@link ConnectionType#LARGE_MESSAGES}, to ensure that servicing large messages doesn't block + * timely service of other requests, message serialization is offloaded to a companion thread pool + * ({@link SocketFactory#synchronousWorkExecutor}). Most of the work will be performed by + * {@link AsyncChannelOutputPlus}. Please see {@link OutboundConnection.LargeMessageDelivery} + * for details. + * + * To prevent fast clients, or slow nodes on the other end of the connection from overwhelming + * a host with enqueued, unsent messages on heap, we impose strict limits on how much memory enqueued, + * undelivered messages can claim. + * + * Every individual connection gets an exclusive permit quota to use - 4MiB by default; every endpoint + * (group of large, small, and urgent connection) is capped at, by default, at 128MiB of undelivered messages, + * and a global limit of 512MiB is imposed on all endpoints combined. 
+ * + * On an attempt to {@link OutboundConnection#enqueue(Message)}, the connection will attempt to allocate + * permits for message-size number of bytes from its exclusive quota; if successful, it will add the + * message to the queue; if unsuccessful, it will need to allocate remainder from both endpoint and global + * reserves, and if it fails to do so, the message will be rejected, and its callbacks, if any, + * immediately expired. + * + * For a more detailed description please see the docs and comments of {@link OutboundConnection}. + * + *

Architecture of inbound messaging

+ * + * {@link InboundMessageHandler} is the core class implementing inbound connection logic, paired + * with {@link FrameDecoder}. Inbound connections are initiated by {@link InboundConnectionInitiator}. + * The primary entry points to these classes are {@link FrameDecoder#channelRead(ShareableBytes)} + * and {@link InboundMessageHandler#process(FrameDecoder.Frame)}. + * + * Netty pipeline for inbound messaging connections generally consists of the following handlers: + * + * [(optional) SslHandler] -> [FrameDecoder] -> [InboundMessageHandler] + * + * {@link FrameDecoder} is responsible for decoding incoming frames and work stashing; {@link InboundMessageHandler} + * then takes decoded frames from the decoder and processes the messages contained in them. + * + * The flow differs between small and large messages. Small ones are deserialized immediately, and only + * then scheduled on the right thread pool for the {@link Verb} for execution. Large messages, OTOH, + * aren't deserialized until they are just about to be executed on the appropriate {@link Stage}. + * + * Similarly to outbound handling, inbound messaging imposes strict memory utilisation limits on individual + * endpoints and on global aggregate consumption, and implements simple flow control, to prevent a single + * fast endpoint from overwhelming a host. + * + * Every individual connection gets an exclusive permit quota to use - 4MiB by default; every endpoint + * (group of large, small, and urgent connection) is capped at, by default, at 128MiB of unprocessed messages, + * and a global limit of 512MiB is imposed on all endpoints combined. + * + * On arrival of a message header, the handler will attempt to allocate permits for message-size number + * of bytes from its exclusive quota; if successful, it will proceed to deserializing and processing the message. 
+ * If unsuccessful, the handler will attempt to allocate the remainder from its endpoint and global reserve; + * if either allocation is unsuccessful, the handler will cease any further frame processing, and tell + * {@link FrameDecoder} to stop reading from the network; subsequently, it will put itself on a special + * {@link org.apache.cassandra.net.InboundMessageHandler.WaitQueue}, to be reactivated once more permits + * become available. + * + * For a more detailed description please see the docs and comments of {@link InboundMessageHandler} and + * {@link FrameDecoder}. + * + *

Observability

+ * + * MessagingService exposes diagnostic counters for both outbound and inbound directions - received and sent + * bytes and message counts, overload bytes and message count, error bytes and error counts, and many more. + * + * See {@link org.apache.cassandra.metrics.InternodeInboundMetrics} and + * {@link org.apache.cassandra.metrics.InternodeOutboundMetrics} for JMX-exposed counters. + * + * We also provide {@code system_views.internode_inbound} and {@code system_views.internode_outbound} virtual tables - + * implemented in {@link org.apache.cassandra.db.virtual.InternodeInboundTable} and + * {@link org.apache.cassandra.db.virtual.InternodeOutboundTable} respectively. + */ +public final class MessagingService extends MessagingServiceMBeanImpl { - public static final String MBEAN_NAME = "org.apache.cassandra.net:type=MessagingService"; + private static final Logger logger = LoggerFactory.getLogger(MessagingService.class); // 8 bits version, so don't waste versions public static final int VERSION_30 = 10; @@ -127,529 +210,142 @@ public final class MessagingService implements MessagingServiceMBean public static final int VERSION_40 = 12; public static final int minimum_version = VERSION_30; public static final int current_version = VERSION_40; + static AcceptVersions accept_messaging = new AcceptVersions(minimum_version, current_version); + static AcceptVersions accept_streaming = new AcceptVersions(current_version, current_version); - public static final byte[] ONE_BYTE = new byte[1]; - - /** - * we preface every message with this number so the recipient can validate the sender is sane - */ - public static final int PROTOCOL_MAGIC = 0xCA552DFA; - - public final MessagingMetrics metrics = new MessagingMetrics(); - - /* All verb handler identifiers */ - public enum Verb + private static class MSHandle { - MUTATION - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - HINT - { - public long getTimeout() - { - return 
DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - READ_REPAIR - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - READ - { - public long getTimeout() - { - return DatabaseDescriptor.getReadRpcTimeout(); - } - }, - REQUEST_RESPONSE, // client-initiated reads and writes - BATCH_STORE - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, // was @Deprecated STREAM_INITIATE, - BATCH_REMOVE - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, // was @Deprecated STREAM_INITIATE_DONE, - @Deprecated STREAM_REPLY, - @Deprecated STREAM_REQUEST, - RANGE_SLICE - { - public long getTimeout() - { - return DatabaseDescriptor.getRangeRpcTimeout(); - } - }, - @Deprecated BOOTSTRAP_TOKEN, - @Deprecated TREE_REQUEST, - @Deprecated TREE_RESPONSE, - @Deprecated JOIN, - GOSSIP_DIGEST_SYN, - GOSSIP_DIGEST_ACK, - GOSSIP_DIGEST_ACK2, - @Deprecated DEFINITIONS_ANNOUNCE, - DEFINITIONS_UPDATE, - TRUNCATE - { - public long getTimeout() - { - return DatabaseDescriptor.getTruncateRpcTimeout(); - } - }, - SCHEMA_CHECK, - @Deprecated INDEX_SCAN, - REPLICATION_FINISHED, - INTERNAL_RESPONSE, // responses to internal calls - COUNTER_MUTATION - { - public long getTimeout() - { - return DatabaseDescriptor.getCounterWriteRpcTimeout(); - } - }, - @Deprecated STREAMING_REPAIR_REQUEST, - @Deprecated STREAMING_REPAIR_RESPONSE, - SNAPSHOT, // Similar to nt snapshot - MIGRATION_REQUEST, - GOSSIP_SHUTDOWN, - _TRACE, // dummy verb so we can use MS.droppedMessagesMap - ECHO, - REPAIR_MESSAGE, - PAXOS_PREPARE - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - PAXOS_PROPOSE - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - PAXOS_COMMIT - { - public long getTimeout() - { - return DatabaseDescriptor.getWriteRpcTimeout(); - } - }, - @Deprecated PAGED_RANGE - { - public long getTimeout() - { - return 
DatabaseDescriptor.getRangeRpcTimeout(); - } - }, - PING - { - public long getTimeout() - { - return DatabaseDescriptor.getPingTimeout(); - } - }, - - // UNUSED verbs were used as padding for backward/forward compatability before 4.0, - // but it wasn't quite as bullet/future proof as needed. We still need to keep these entries - // around, at least for a major rev or two (post-4.0). see CASSANDRA-13993 for a discussion. - // For now, though, the UNUSED are legacy values (placeholders, basically) that should only be used - // for correctly adding VERBs that need to be emergency additions to 3.0/3.11. - // We can reclaim them (their id's, to be correct) in future versions, if desireed, though. - UNUSED_2, - UNUSED_3, - UNUSED_4, - UNUSED_5, - _SAMPLE // dummy verb so we can use MS.droppedMessagesMap - ; - // add new verbs after the existing verbs, since we serialize by ordinal. - - private final int id; - Verb() - { - id = ordinal(); - } - - /** - * Unused, but it is an extension point for adding custom verbs - * @param id - */ - Verb(int id) - { - this.id = id; - } - - public long getTimeout() - { - return DatabaseDescriptor.getRpcTimeout(); - } - - public int getId() - { - return id; - } - private static final IntObjectMap idToVerbMap = new IntObjectOpenHashMap<>(values().length); - static - { - for (Verb v : values()) - { - Verb existing = idToVerbMap.put(v.getId(), v); - if (existing != null) - throw new IllegalArgumentException("cannot have two verbs that map to the same id: " + v + " and " + existing); - } - } - - public static Verb fromId(int id) - { - return idToVerbMap.get(id); - } + public static final MessagingService instance = new MessagingService(false); } - public static final EnumMap verbStages = new EnumMap(MessagingService.Verb.class) - {{ - put(Verb.MUTATION, Stage.MUTATION); - put(Verb.COUNTER_MUTATION, Stage.COUNTER_MUTATION); - put(Verb.READ_REPAIR, Stage.MUTATION); - put(Verb.HINT, Stage.MUTATION); - put(Verb.TRUNCATE, Stage.MUTATION); - 
put(Verb.PAXOS_PREPARE, Stage.MUTATION); - put(Verb.PAXOS_PROPOSE, Stage.MUTATION); - put(Verb.PAXOS_COMMIT, Stage.MUTATION); - put(Verb.BATCH_STORE, Stage.MUTATION); - put(Verb.BATCH_REMOVE, Stage.MUTATION); - - put(Verb.READ, Stage.READ); - put(Verb.RANGE_SLICE, Stage.READ); - put(Verb.INDEX_SCAN, Stage.READ); - put(Verb.PAGED_RANGE, Stage.READ); - - put(Verb.REQUEST_RESPONSE, Stage.REQUEST_RESPONSE); - put(Verb.INTERNAL_RESPONSE, Stage.INTERNAL_RESPONSE); - - put(Verb.STREAM_REPLY, Stage.MISC); // actually handled by FileStreamTask and streamExecutors - put(Verb.STREAM_REQUEST, Stage.MISC); - put(Verb.REPLICATION_FINISHED, Stage.MISC); - put(Verb.SNAPSHOT, Stage.MISC); - - put(Verb.TREE_REQUEST, Stage.ANTI_ENTROPY); - put(Verb.TREE_RESPONSE, Stage.ANTI_ENTROPY); - put(Verb.STREAMING_REPAIR_REQUEST, Stage.ANTI_ENTROPY); - put(Verb.STREAMING_REPAIR_RESPONSE, Stage.ANTI_ENTROPY); - put(Verb.REPAIR_MESSAGE, Stage.ANTI_ENTROPY); - put(Verb.GOSSIP_DIGEST_ACK, Stage.GOSSIP); - put(Verb.GOSSIP_DIGEST_ACK2, Stage.GOSSIP); - put(Verb.GOSSIP_DIGEST_SYN, Stage.GOSSIP); - put(Verb.GOSSIP_SHUTDOWN, Stage.GOSSIP); - - put(Verb.DEFINITIONS_UPDATE, Stage.MIGRATION); - put(Verb.SCHEMA_CHECK, Stage.MIGRATION); - put(Verb.MIGRATION_REQUEST, Stage.MIGRATION); - put(Verb.INDEX_SCAN, Stage.READ); - put(Verb.REPLICATION_FINISHED, Stage.MISC); - put(Verb.SNAPSHOT, Stage.MISC); - put(Verb.ECHO, Stage.GOSSIP); + public static MessagingService instance() + { + return MSHandle.instance; + } - put(Verb.UNUSED_2, Stage.INTERNAL_RESPONSE); - put(Verb.UNUSED_3, Stage.INTERNAL_RESPONSE); - put(Verb.UNUSED_4, Stage.INTERNAL_RESPONSE); - put(Verb.UNUSED_5, Stage.INTERNAL_RESPONSE); + public final SocketFactory socketFactory = new SocketFactory(); + public final LatencySubscribers latencySubscribers = new LatencySubscribers(); + public final RequestCallbacks callbacks = new RequestCallbacks(this); - put(Verb.PING, Stage.READ); - }}; + // a public hook for filtering messages intended for delivery to 
this node + public final InboundSink inboundSink = new InboundSink(this); - /** - * Messages we receive from peers have a Verb that tells us what kind of message it is. - * Most of the time, this is enough to determine how to deserialize the message payload. - * The exception is the REQUEST_RESPONSE verb, which just means "a reply to something you told me to do." - * Traditionally, this was fine since each VerbHandler knew what type of payload it expected, and - * handled the deserialization itself. Now that we do that in ITC, to avoid the extra copy to an - * intermediary byte[] (See CASSANDRA-3716), we need to wire that up to the CallbackInfo object - * (see below). - */ - public static final EnumMap> verbSerializers = new EnumMap>(Verb.class) - {{ - put(Verb.REQUEST_RESPONSE, CallbackDeterminedSerializer.instance); - put(Verb.INTERNAL_RESPONSE, CallbackDeterminedSerializer.instance); + // the inbound global reserve limits and associated wait queue + private final InboundMessageHandlers.GlobalResourceLimits inboundGlobalReserveLimits = new InboundMessageHandlers.GlobalResourceLimits( + new ResourceLimits.Concurrent(DatabaseDescriptor.getInternodeApplicationReceiveQueueReserveGlobalCapacityInBytes())); - put(Verb.MUTATION, Mutation.serializer); - put(Verb.READ_REPAIR, Mutation.serializer); - put(Verb.READ, ReadCommand.serializer); - put(Verb.RANGE_SLICE, ReadCommand.serializer); - put(Verb.PAGED_RANGE, ReadCommand.serializer); - put(Verb.BOOTSTRAP_TOKEN, BootStrapper.StringSerializer.instance); - put(Verb.REPAIR_MESSAGE, RepairMessage.serializer); - put(Verb.GOSSIP_DIGEST_ACK, GossipDigestAck.serializer); - put(Verb.GOSSIP_DIGEST_ACK2, GossipDigestAck2.serializer); - put(Verb.GOSSIP_DIGEST_SYN, GossipDigestSyn.serializer); - put(Verb.DEFINITIONS_UPDATE, MigrationManager.MigrationsSerializer.instance); - put(Verb.TRUNCATE, Truncation.serializer); - put(Verb.REPLICATION_FINISHED, null); - put(Verb.COUNTER_MUTATION, CounterMutation.serializer); - put(Verb.SNAPSHOT, 
SnapshotCommand.serializer); - put(Verb.ECHO, EchoMessage.serializer); - put(Verb.PAXOS_PREPARE, Commit.serializer); - put(Verb.PAXOS_PROPOSE, Commit.serializer); - put(Verb.PAXOS_COMMIT, Commit.serializer); - put(Verb.HINT, HintMessage.serializer); - put(Verb.BATCH_STORE, Batch.serializer); - put(Verb.BATCH_REMOVE, UUIDSerializer.serializer); - put(Verb.PING, PingMessage.serializer); - }}; + // the socket bindings we accept incoming connections on + private final InboundSockets inboundSockets = new InboundSockets(new InboundConnectionSettings() + .withHandlers(this::getInbound) + .withSocketFactory(socketFactory)); - /** - * A Map of what kind of serializer to wire up to a REQUEST_RESPONSE callback, based on outbound Verb. - */ - public static final EnumMap> callbackDeserializers = new EnumMap>(Verb.class) - {{ - put(Verb.MUTATION, WriteResponse.serializer); - put(Verb.HINT, HintResponse.serializer); - put(Verb.READ_REPAIR, WriteResponse.serializer); - put(Verb.COUNTER_MUTATION, WriteResponse.serializer); - put(Verb.RANGE_SLICE, ReadResponse.serializer); - put(Verb.PAGED_RANGE, ReadResponse.serializer); - put(Verb.READ, ReadResponse.serializer); - put(Verb.TRUNCATE, TruncateResponse.serializer); - put(Verb.SNAPSHOT, null); + // a public hook for filtering messages intended for delivery to another node + public final OutboundSink outboundSink = new OutboundSink(this::doSend); - put(Verb.MIGRATION_REQUEST, MigrationManager.MigrationsSerializer.instance); - put(Verb.SCHEMA_CHECK, UUIDSerializer.serializer); - put(Verb.BOOTSTRAP_TOKEN, BootStrapper.StringSerializer.instance); - put(Verb.REPLICATION_FINISHED, null); + final ResourceLimits.Limit outboundGlobalReserveLimit = + new ResourceLimits.Concurrent(DatabaseDescriptor.getInternodeApplicationSendQueueReserveGlobalCapacityInBytes()); - put(Verb.PAXOS_PREPARE, PrepareResponse.serializer); - put(Verb.PAXOS_PROPOSE, BooleanSerializer.serializer); + // back-pressure implementation + private final BackPressureStrategy 
backPressure = DatabaseDescriptor.getBackPressureStrategy(); - put(Verb.BATCH_STORE, WriteResponse.serializer); - put(Verb.BATCH_REMOVE, WriteResponse.serializer); - }}; + private volatile boolean isShuttingDown; - /* This records all the results mapped by message Id */ - private final ExpiringMap callbacks; + @VisibleForTesting + MessagingService(boolean testOnly) + { + super(testOnly); + OutboundConnections.scheduleUnusedConnectionMonitoring(this, ScheduledExecutors.scheduledTasks, 1L, TimeUnit.HOURS); + } /** - * a placeholder class that means "deserialize using the callback." We can't implement this without - * special-case code in InboundTcpConnection because there is no way to pass the message id to IVersionedSerializer. + * Send a non-mutation message to a given endpoint. This method specifies a callback + * which is invoked with the actual response. + * + * @param message message to be sent. + * @param to endpoint to which the message needs to be sent + * @param cb callback interface which is used to pass the responses or + * suggest that a timeout occurred to the invoker of the send(). 
*/ - static class CallbackDeterminedSerializer implements IVersionedSerializer + public void sendWithCallback(Message message, InetAddressAndPort to, RequestCallback cb) { - public static final CallbackDeterminedSerializer instance = new CallbackDeterminedSerializer(); - - public Object deserialize(DataInputPlus in, int version) throws IOException - { - throw new UnsupportedOperationException(); - } - - public void serialize(Object o, DataOutputPlus out, int version) throws IOException - { - throw new UnsupportedOperationException(); - } - - public long serializedSize(Object o, int version) - { - throw new UnsupportedOperationException(); - } + sendWithCallback(message, to, cb, null); } - public static IVersionedSerializer getVerbSerializer(Verb verb, int id) + public void sendWithCallback(Message message, InetAddressAndPort to, RequestCallback cb, ConnectionType specifyConnection) { - IVersionedSerializer serializer = verbSerializers.get(verb); - if (serializer instanceof MessagingService.CallbackDeterminedSerializer) - { - CallbackInfo callback = MessagingService.instance().getRegisteredCallback(id); - if (callback == null) - return null; - - serializer = callback.serializer; - } - return serializer; + callbacks.addWithExpiration(cb, message, to); + updateBackPressureOnSend(to, cb, message); + if (cb.invokeOnFailure() && !message.callBackOnFailure()) + message = message.withCallBackOnFailure(); + send(message, to, specifyConnection); } - /* Lookup table for registering message handlers based on the verb. 
*/ - private final Map verbHandlers; - - @VisibleForTesting - public final ConcurrentMap channelManagers = new NonBlockingHashMap<>(); - final List serverChannels = Lists.newArrayList(); - - private static final Logger logger = LoggerFactory.getLogger(MessagingService.class); - private static final int LOG_DROPPED_INTERVAL_IN_MS = 5000; - - private final SimpleCondition listenGate; - /** - * Verbs it's okay to drop if the request has been queued longer than the request timeout. These - * all correspond to client requests or something triggered by them; we don't want to - * drop internal messages like bootstrap or repair notifications. + * Send a mutation message or a Paxos Commit to a given endpoint. This method specifies a callback + * which is invoked with the actual response. + * Also holds the message (only mutation messages) to determine if it + * needs to trigger a hint (uses StorageProxy for that). + * + * @param message message to be sent. + * @param to endpoint to which the message needs to be sent + * @param handler callback interface which is used to pass the responses or + * suggest that a timeout occurred to the invoker of the send(). 
*/ - public static final EnumSet DROPPABLE_VERBS = EnumSet.of(Verb._TRACE, - Verb._SAMPLE, - Verb.MUTATION, - Verb.COUNTER_MUTATION, - Verb.HINT, - Verb.READ_REPAIR, - Verb.READ, - Verb.RANGE_SLICE, - Verb.PAGED_RANGE, - Verb.REQUEST_RESPONSE, - Verb.BATCH_STORE, - Verb.BATCH_REMOVE); - - private static final class DroppedMessages - { - final DroppedMessageMetrics metrics; - final AtomicInteger droppedInternal; - final AtomicInteger droppedCrossNode; - - DroppedMessages(Verb verb) - { - this(new DroppedMessageMetrics(verb)); - } - - DroppedMessages(DroppedMessageMetrics metrics) - { - this.metrics = metrics; - this.droppedInternal = new AtomicInteger(0); - this.droppedCrossNode = new AtomicInteger(0); - } - } - - @VisibleForTesting - public void resetDroppedMessagesMap(String scope) + public void sendWriteWithCallback(Message message, Replica to, AbstractWriteResponseHandler handler, boolean allowHints) { - for (Verb verb : droppedMessagesMap.keySet()) - droppedMessagesMap.put(verb, new DroppedMessages(new DroppedMessageMetrics(metricName -> { - return new CassandraMetricsRegistry.MetricName("DroppedMessages", metricName, scope); - }))); + assert message.callBackOnFailure(); + callbacks.addWithExpiration(handler, message, to, handler.consistencyLevel(), allowHints); + updateBackPressureOnSend(to.endpoint(), handler, message); + send(message, to.endpoint(), null); } - // total dropped message counts for server lifetime - private final Map droppedMessagesMap = new EnumMap<>(Verb.class); - - private final List subscribers = new ArrayList(); - - // protocol versions of the other nodes in the cluster - private final ConcurrentMap versions = new NonBlockingHashMap<>(); - - // message sinks are a testing hook - private final Set messageSinks = new CopyOnWriteArraySet<>(); - - // back-pressure implementation - private final BackPressureStrategy backPressure = DatabaseDescriptor.getBackPressureStrategy(); - - private static class MSHandle + /** + * Send a message to a given 
endpoint. This method adheres to the fire and forget + * style messaging. + * + * @param message messages to be sent. + * @param to endpoint to which the message needs to be sent + */ + public void send(Message message, InetAddressAndPort to) { - public static final MessagingService instance = new MessagingService(false); + send(message, to, null); } - public static MessagingService instance() + public void send(Message message, InetAddressAndPort to, ConnectionType specifyConnection) { - return MSHandle.instance; - } + if (logger.isTraceEnabled()) + { + logger.trace("{} sending {} to {}@{}", FBUtilities.getBroadcastAddressAndPort(), message.verb(), message.id(), to); - private static class MSTestHandle - { - public static final MessagingService instance = new MessagingService(true); - } + if (to.equals(FBUtilities.getBroadcastAddressAndPort())) + logger.trace("Message-to-self {} going over MessagingService", message); + } - static MessagingService test() - { - return MSTestHandle.instance; + outboundSink.accept(message, to, specifyConnection); } - private MessagingService(boolean testOnly) + private void doSend(Message message, InetAddressAndPort to, ConnectionType specifyConnection) { - for (Verb verb : DROPPABLE_VERBS) - droppedMessagesMap.put(verb, new DroppedMessages(verb)); - - listenGate = new SimpleCondition(); - verbHandlers = new EnumMap<>(Verb.class); - if (!testOnly) + // expire the callback if the message failed to enqueue (failed to establish a connection or exceeded queue capacity) + while (true) { - Runnable logDropped = new Runnable() + OutboundConnections connections = getOutbound(to); + try { - public void run() - { - logDroppedMessages(); - } - }; - ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay(logDropped, LOG_DROPPED_INTERVAL_IN_MS, LOG_DROPPED_INTERVAL_IN_MS, TimeUnit.MILLISECONDS); - } - - Function>, ?> timeoutReporter = new Function>, Object>() - { - public Object apply(Pair> pair) + connections.enqueue(message, 
specifyConnection); + return; + } + catch (ClosedChannelException e) { - final CallbackInfo expiredCallbackInfo = pair.right.value; + if (isShuttingDown) + return; // just drop the message, and let others clean up - maybeAddLatency(expiredCallbackInfo.callback, expiredCallbackInfo.target, pair.right.timeout); - - ConnectionMetrics.totalTimeouts.mark(); - markTimeout(expiredCallbackInfo.target); - - if (expiredCallbackInfo.callback.supportsBackPressure()) - { - updateBackPressureOnReceive(expiredCallbackInfo.target, expiredCallbackInfo.callback, true); - } - - if (expiredCallbackInfo.isFailureCallback()) - { - StageManager.getStage(Stage.INTERNAL_RESPONSE).submit(new Runnable() - { - @Override - public void run() - { - ((IAsyncCallbackWithFailure)expiredCallbackInfo.callback).onFailure(expiredCallbackInfo.target, RequestFailureReason.UNKNOWN); - } - }); - } - - if (expiredCallbackInfo.shouldHint()) - { - WriteCallbackInfo writeCallbackInfo = ((WriteCallbackInfo) expiredCallbackInfo); - Mutation mutation = writeCallbackInfo.mutation(); - return StorageProxy.submitHint(mutation, writeCallbackInfo.getReplica(), null); - } - - return null; + // remove the connection and try again + channelManagers.remove(to, connections); } - }; - - callbacks = new ExpiringMap<>(DatabaseDescriptor.getMinRpcTimeout(), timeoutReporter); - - if (!testOnly) - { - MBeanWrapper.instance.registerMBean(this, MBEAN_NAME); } } - public void addMessageSink(IMessageSink sink) - { - messageSinks.add(sink); - } - - public void removeMessageSink(IMessageSink sink) - { - messageSinks.remove(sink); - } - - public void clearMessageSinks() - { - messageSinks.clear(); - } - /** * Updates the back-pressure state on sending to the given host if enabled and the given message callback supports it. * @@ -657,7 +353,7 @@ public void clearMessageSinks() * @param callback The message callback. * @param message The actual message. 
*/ - public void updateBackPressureOnSend(InetAddressAndPort host, IAsyncCallback callback, MessageOut message) + void updateBackPressureOnSend(InetAddressAndPort host, RequestCallback callback, Message message) { if (DatabaseDescriptor.backPressureEnabled() && callback.supportsBackPressure()) { @@ -674,7 +370,7 @@ public void updateBackPressureOnSend(InetAddressAndPort host, IAsyncCallback cal * @param callback The message callback. * @param timeout True if updated following a timeout, false otherwise. */ - public void updateBackPressureOnReceive(InetAddressAndPort host, IAsyncCallback callback, boolean timeout) + void updateBackPressureOnReceive(InetAddressAndPort host, RequestCallback callback, boolean timeout) { if (DatabaseDescriptor.backPressureEnabled() && callback.supportsBackPressure()) { @@ -701,1027 +397,199 @@ public void applyBackPressure(Iterable hosts, long timeoutIn { if (DatabaseDescriptor.backPressureEnabled()) { - Set states = new HashSet(); + Set states = new HashSet<>(); for (InetAddressAndPort host : hosts) { if (host.equals(FBUtilities.getBroadcastAddressAndPort())) continue; - OutboundMessagingPool pool = getMessagingConnection(host); - if (pool != null) - states.add(pool.getBackPressureState()); + states.add(getOutbound(host).getBackPressureState()); } - backPressure.apply(states, timeoutInNanos, TimeUnit.NANOSECONDS); + //noinspection unchecked + backPressure.apply(states, timeoutInNanos, NANOSECONDS); } } BackPressureState getBackPressureState(InetAddressAndPort host) { - OutboundMessagingPool messagingConnection = getMessagingConnection(host); - return messagingConnection != null ? 
messagingConnection.getBackPressureState() : null; + return getOutbound(host).getBackPressureState(); } - void markTimeout(InetAddressAndPort addr) + void markExpiredCallback(InetAddressAndPort addr) { - OutboundMessagingPool conn = channelManagers.get(addr); + OutboundConnections conn = channelManagers.get(addr); if (conn != null) - conn.incrementTimeout(); + conn.incrementExpiredCallbackCount(); } /** - * Track latency information for the dynamic snitch + * Only to be invoked once we believe the endpoint will never be contacted again. * - * @param cb the callback associated with this message -- this lets us know if it's a message type we're interested in - * @param address the host that replied to the message - * @param latency + * We close the connection after a five minute delay, to give asynchronous operations a chance to terminate */ - public void maybeAddLatency(IAsyncCallback cb, InetAddressAndPort address, long latency) + public void closeOutbound(InetAddressAndPort to) { - if (cb.isLatencyForSnitch()) - addLatency(address, latency); + OutboundConnections pool = channelManagers.get(to); + if (pool != null) + pool.scheduleClose(5L, MINUTES, true) + .addListener(future -> channelManagers.remove(to, pool)); } - public void addLatency(InetAddressAndPort address, long latency) + /** + * Only to be invoked once we believe the connections will never be used again. + */ + void closeOutboundNow(OutboundConnections connections) { - for (ILatencySubscriber subscriber : subscribers) - subscriber.receiveTiming(address, latency); + connections.close(true).addListener( + future -> channelManagers.remove(connections.template().to, connections) + ); } /** - * called from gossiper when it notices a node is not responding. + * Only to be invoked once we believe the connections will never be used again. 
*/ - public void convict(InetAddressAndPort ep) + public void removeInbound(InetAddressAndPort from) { - logger.trace("Resetting pool for {}", ep); - reset(ep); + InboundMessageHandlers handlers = messageHandlers.remove(from); + if (null != handlers) + handlers.releaseMetrics(); } - public void listen() + /** + * Closes any current open channel/connection to the endpoint, but does not cause any message loss, and we will + * try to re-establish connections immediately + */ + public void interruptOutbound(InetAddressAndPort to) { - listen(DatabaseDescriptor.getInternodeMessagingEncyptionOptions()); + OutboundConnections pool = channelManagers.get(to); + if (pool != null) + pool.interrupt(); } - public void listen(ServerEncryptionOptions serverEncryptionOptions) + /** + * Reconnect to the peer using the given {@code addr}. Outstanding messages in each channel will be sent on the + * current channel. Typically this function is used for something like EC2 public IP addresses which need to be used + * for communication between EC2 regions. 
+ * + * @param address IP Address to identify the peer + * @param preferredAddress IP Address to use (and prefer) going forward for connecting to the peer + */ + @SuppressWarnings("UnusedReturnValue") + public Future maybeReconnectWithNewIp(InetAddressAndPort address, InetAddressAndPort preferredAddress) { - callbacks.reset(); // hack to allow tests to stop/restart MS - listen(FBUtilities.getLocalAddressAndPort(), serverEncryptionOptions); - if (shouldListenOnBroadcastAddress()) - listen(FBUtilities.getBroadcastAddressAndPort(), serverEncryptionOptions); - listenGate.signalAll(); - } + if (!SystemKeyspace.updatePreferredIP(address, preferredAddress)) + return null; - public static boolean shouldListenOnBroadcastAddress() - { - return DatabaseDescriptor.shouldListenOnBroadcastAddress() - && !FBUtilities.getLocalAddressAndPort().equals(FBUtilities.getBroadcastAddressAndPort()); + OutboundConnections messagingPool = channelManagers.get(address); + if (messagingPool != null) + return messagingPool.reconnectWithNewIp(preferredAddress); + + return null; } /** - * Listen on the specified port. - * - * @param localEp InetAddressAndPort whose port to listen on. + * Wait for callbacks and don't allow any more to be created (since they could require writing hints) */ - private void listen(InetAddressAndPort localEp, ServerEncryptionOptions serverEncryptionOptions) throws ConfigurationException + public void shutdown() { - IInternodeAuthenticator authenticator = DatabaseDescriptor.getInternodeAuthenticator(); - int receiveBufferSize = DatabaseDescriptor.getInternodeRecvBufferSize(); + shutdown(1L, MINUTES, true, true); + } - // this is the legacy socket, for letting peer nodes that haven't upgrade yet connect to this node. - // should only occur during cluster upgrade. we can remove this block at 5.0! 
- if (serverEncryptionOptions.enabled && serverEncryptionOptions.enable_legacy_ssl_storage_port) + public void shutdown(long timeout, TimeUnit units, boolean shutdownGracefully, boolean shutdownExecutors) + { + isShuttingDown = true; + logger.info("Waiting for messaging service to quiesce"); + // We may need to schedule hints on the mutation stage, so it's erroneous to shut down the mutation stage first + assert !StageManager.getStage(MUTATION).isShutdown(); + + if (shutdownGracefully) + { + callbacks.shutdownGracefully(); + List> closing = new ArrayList<>(); + for (OutboundConnections pool : channelManagers.values()) + closing.add(pool.close(true)); + + long deadline = System.nanoTime() + units.toNanos(timeout); + maybeFail(() -> new FutureCombiner(closing).get(timeout, units), + () -> { + List inboundExecutors = new ArrayList<>(); + inboundSockets.close(synchronizedList(inboundExecutors)::add).get(); + ExecutorUtils.awaitTermination(1L, TimeUnit.MINUTES, inboundExecutors); + }, + () -> { + if (shutdownExecutors) + shutdownExecutors(deadline); + }, + () -> callbacks.awaitTerminationUntil(deadline), + inboundSink::clear, + outboundSink::clear); + } + else { - // clone the encryption options, and explicitly set the optional field to false - // (do not allow non-TLS connections on the legacy ssl port) - ServerEncryptionOptions legacyEncOptions = new ServerEncryptionOptions(serverEncryptionOptions); - legacyEncOptions.optional = false; + callbacks.shutdownNow(false); + List> closing = new ArrayList<>(); + List inboundExecutors = synchronizedList(new ArrayList()); + closing.add(inboundSockets.close(inboundExecutors::add)); + for (OutboundConnections pool : channelManagers.values()) + closing.add(pool.close(false)); - InetAddressAndPort localAddr = InetAddressAndPort.getByAddressOverrideDefaults(localEp.address, DatabaseDescriptor.getSSLStoragePort()); - ChannelGroup channelGroup = new DefaultChannelGroup("LegacyEncryptedInternodeMessagingGroup", 
NettyFactory.executorForChannelGroups()); - InboundInitializer initializer = new InboundInitializer(authenticator, legacyEncOptions, channelGroup); - Channel encryptedChannel = NettyFactory.instance.createInboundChannel(localAddr, initializer, receiveBufferSize); - serverChannels.add(new ServerChannel(encryptedChannel, channelGroup, localAddr, ServerChannel.SecurityLevel.REQUIRED)); + long deadline = System.nanoTime() + units.toNanos(timeout); + maybeFail(() -> new FutureCombiner(closing).get(timeout, units), + () -> { + if (shutdownExecutors) + shutdownExecutors(deadline); + }, + () -> ExecutorUtils.awaitTermination(timeout, units, inboundExecutors), + () -> callbacks.awaitTerminationUntil(deadline), + inboundSink::clear, + outboundSink::clear); } + } - // this is for the socket that can be plain, only ssl, or optional plain/ssl - assert localEp.port == DatabaseDescriptor.getStoragePort() : String.format("Local endpoint port %d doesn't match YAML configured port %d%n", localEp.port, DatabaseDescriptor.getStoragePort()); - InetAddressAndPort localAddr = InetAddressAndPort.getByAddressOverrideDefaults(localEp.address, DatabaseDescriptor.getStoragePort()); - ChannelGroup channelGroup = new DefaultChannelGroup("InternodeMessagingGroup", NettyFactory.executorForChannelGroups()); - InboundInitializer initializer = new InboundInitializer(authenticator, serverEncryptionOptions, channelGroup); - Channel channel = NettyFactory.instance.createInboundChannel(localAddr, initializer, receiveBufferSize); - ServerChannel.SecurityLevel securityLevel = !serverEncryptionOptions.enabled ? ServerChannel.SecurityLevel.NONE : - serverEncryptionOptions.optional ? 
ServerChannel.SecurityLevel.OPTIONAL : - ServerChannel.SecurityLevel.REQUIRED; - serverChannels.add(new ServerChannel(channel, channelGroup, localAddr, securityLevel)); + private void shutdownExecutors(long deadlineNanos) throws TimeoutException, InterruptedException + { + socketFactory.shutdownNow(); + socketFactory.awaitTerminationUntil(deadlineNanos); } - /** - * A simple struct to wrap up the the components needed for each listening socket. - *

- * The {@link #securityLevel} is captured independently of the {@link #channel} as there's no real way to inspect a s - * erver-side 'channel' to check if it using TLS or not (the channel's configured pipeline will only apply to - * connections that get created, so it's not inspectible). {@link #securityLevel} is really only used for testing, anyway. - */ - @VisibleForTesting - static class ServerChannel - { - /** - * Declares the type of TLS used with the channel. - */ - enum SecurityLevel { NONE, OPTIONAL, REQUIRED } - - /** - * The base {@link Channel} that is doing the spcket listen/accept. - */ - private final Channel channel; - - /** - * A group of the open, inbound {@link Channel}s connected to this node. This is mostly interesting so that all of - * the inbound connections/channels can be closed when the listening socket itself is being closed. - */ - private final ChannelGroup connectedChannels; - private final InetAddressAndPort address; - private final SecurityLevel securityLevel; - - private ServerChannel(Channel channel, ChannelGroup channelGroup, InetAddressAndPort address, SecurityLevel securityLevel) - { - this.channel = channel; - this.connectedChannels = channelGroup; - this.address = address; - this.securityLevel = securityLevel; - } - - void close() - { - if (channel.isOpen()) - channel.close().awaitUninterruptibly(); - connectedChannels.close().awaitUninterruptibly(); - } - - int size() - { - return connectedChannels.size(); - } - - /** - * For testing only! 
- */ - Channel getChannel() - { - return channel; - } - - InetAddressAndPort getAddress() - { - return address; - } - - SecurityLevel getSecurityLevel() - { - return securityLevel; - } - } - - public void waitUntilListening() - { - try - { - listenGate.await(); - } - catch (InterruptedException ie) - { - logger.trace("await interrupted"); - } - } - - public boolean isListening() - { - return listenGate.isSignaled(); - } - - - public void destroyConnectionPool(InetAddressAndPort to) - { - OutboundMessagingPool pool = channelManagers.remove(to); - if (pool != null) - pool.close(true); - } - - /** - * Reconnect to the peer using the given {@code addr}. Outstanding messages in each channel will be sent on the - * current channel. Typically this function is used for something like EC2 public IP addresses which need to be used - * for communication between EC2 regions. - * - * @param address IP Address to identify the peer - * @param preferredAddress IP Address to use (and prefer) going forward for connecting to the peer - */ - public void reconnectWithNewIp(InetAddressAndPort address, InetAddressAndPort preferredAddress) - { - SystemKeyspace.updatePreferredIP(address, preferredAddress); - - OutboundMessagingPool messagingPool = channelManagers.get(address); - if (messagingPool != null) - messagingPool.reconnectWithNewIp(InetAddressAndPort.getByAddressOverrideDefaults(preferredAddress.address, portFor(address))); - } - - private void reset(InetAddressAndPort address) - { - OutboundMessagingPool messagingPool = channelManagers.remove(address); - if (messagingPool != null) - messagingPool.close(false); - } - - public InetAddressAndPort getCurrentEndpoint(InetAddressAndPort publicAddress) - { - OutboundMessagingPool messagingPool = getMessagingConnection(publicAddress); - return messagingPool != null ? messagingPool.getPreferredRemoteAddr() : null; - } - - /** - * Register a verb and the corresponding verb handler with the - * Messaging Service. 
- * - * @param verb - * @param verbHandler handler for the specified verb - */ - public void registerVerbHandlers(Verb verb, IVerbHandler verbHandler) - { - assert !verbHandlers.containsKey(verb); - verbHandlers.put(verb, verbHandler); - } - - /** - * SHOULD ONLY BE USED FOR TESTING!! - */ - public void removeVerbHandler(Verb verb) - { - verbHandlers.remove(verb); - } - - /** - * This method returns the verb handler associated with the registered - * verb. If no handler has been registered then null is returned. - * - * @param type for which the verb handler is sought - * @return a reference to IVerbHandler which is the handler for the specified verb - */ - public IVerbHandler getVerbHandler(Verb type) - { - return verbHandlers.get(type); - } - - public int addWriteCallback(IAsyncCallback cb, MessageOut message, InetAddressAndPort to, long timeout, boolean failureCallback) - { - assert message.verb != Verb.MUTATION; // mutations need to call the overload with a ConsistencyLevel - int messageId = nextId(); - CallbackInfo previous = callbacks.put(messageId, new CallbackInfo(to, cb, callbackDeserializers.get(message.verb), failureCallback), timeout); - assert previous == null : String.format("Callback already exists for id %d! (%s)", messageId, previous); - return messageId; - } - - public int addWriteCallback(IAsyncCallback cb, - MessageOut message, - Replica to, - long timeout, - ConsistencyLevel consistencyLevel, - boolean allowHints) - { - assert message.verb == Verb.MUTATION - || message.verb == Verb.COUNTER_MUTATION - || message.verb == Verb.PAXOS_COMMIT; - int messageId = nextId(); - - CallbackInfo previous = callbacks.put(messageId, - new WriteCallbackInfo(to, - cb, - message, - callbackDeserializers.get(message.verb), - consistencyLevel, - allowHints), - timeout); - assert previous == null : String.format("Callback already exists for id %d! 
(%s)", messageId, previous); - return messageId; - } - - private static final AtomicInteger idGen = new AtomicInteger(0); - - private static int nextId() - { - return idGen.incrementAndGet(); - } - - public int sendRR(MessageOut message, InetAddressAndPort to, IAsyncCallback cb) - { - return sendRR(message, to, cb, message.getTimeout(), false); - } - - public int sendRRWithFailure(MessageOut message, InetAddressAndPort to, IAsyncCallbackWithFailure cb) - { - return sendRR(message, to, cb, message.getTimeout(), true); - } - - /** - * Send a non-mutation message to a given endpoint. This method specifies a callback - * which is invoked with the actual response. - * - * @param message message to be sent. - * @param to endpoint to which the message needs to be sent - * @param cb callback interface which is used to pass the responses or - * suggest that a timeout occurred to the invoker of the send(). - * @param timeout the timeout used for expiration - * @return an reference to message id used to match with the result - */ - public int sendRR(MessageOut message, InetAddressAndPort to, IAsyncCallback cb, long timeout, boolean failureCallback) - { - int id = addWriteCallback(cb, message, to, timeout, failureCallback); - updateBackPressureOnSend(to, cb, message); - sendOneWay(failureCallback ? message.withParameter(ParameterType.FAILURE_CALLBACK, ONE_BYTE) : message, id, to); - return id; - } - - /** - * Send a mutation message or a Paxos Commit to a given endpoint. This method specifies a callback - * which is invoked with the actual response. - * Also holds the message (only mutation messages) to determine if it - * needs to trigger a hint (uses StorageProxy for that). - * - * @param message message to be sent. - * @param to endpoint to which the message needs to be sent - * @param handler callback interface which is used to pass the responses or - * suggest that a timeout occurred to the invoker of the send(). 
- * @return an reference to message id used to match with the result - */ - public int sendWriteRR(MessageOut message, - Replica to, - AbstractWriteResponseHandler handler, - boolean allowHints) - { - int id = addWriteCallback(handler, message, to, message.getTimeout(), handler.consistencyLevel(), allowHints); - updateBackPressureOnSend(to.endpoint(), handler, message); - sendOneWay(message.withParameter(ParameterType.FAILURE_CALLBACK, ONE_BYTE), id, to.endpoint()); - return id; - } - - public void sendOneWay(MessageOut message, InetAddressAndPort to) - { - sendOneWay(message, nextId(), to); - } - - public void sendReply(MessageOut message, int id, InetAddressAndPort to) - { - sendOneWay(message, id, to); - } - - /** - * Send a message to a given endpoint. This method adheres to the fire and forget - * style messaging. - * - * @param message messages to be sent. - * @param to endpoint to which the message needs to be sent - */ - public void sendOneWay(MessageOut message, int id, InetAddressAndPort to) - { - if (logger.isTraceEnabled()) - logger.trace("{} sending {} to {}@{}", FBUtilities.getBroadcastAddressAndPort(), message.verb, id, to); - - if (to.equals(FBUtilities.getBroadcastAddressAndPort())) - logger.trace("Message-to-self {} going over MessagingService", message); - - // message sinks are a testing hook - for (IMessageSink ms : messageSinks) - if (!ms.allowOutgoingMessage(message, id, to)) - return; - - OutboundMessagingPool outboundMessagingPool = getMessagingConnection(to); - if (outboundMessagingPool != null) - outboundMessagingPool.sendMessage(message, id); - } - - public AsyncOneResponse sendRR(MessageOut message, InetAddressAndPort to) - { - AsyncOneResponse iar = new AsyncOneResponse(); - sendRR(message, to, iar); - return iar; - } - - public void register(ILatencySubscriber subcriber) - { - subscribers.add(subcriber); - } - - public void clearCallbacksUnsafe() - { - callbacks.reset(); - } - - /** - * Wait for callbacks and don't allow any more to 
be created (since they could require writing hints) - */ - public void shutdown() - { - shutdown(false); - } - - public void shutdown(boolean isTest) - { - logger.info("Waiting for messaging service to quiesce"); - // We may need to schedule hints on the mutation stage, so it's erroneous to shut down the mutation stage first - assert !StageManager.getStage(Stage.MUTATION).isShutdown(); - - // the important part - if (!callbacks.shutdownBlocking()) - logger.warn("Failed to wait for messaging service callbacks shutdown"); - - // attempt to humor tests that try to stop and restart MS - try - { - // first close the recieve channels - for (ServerChannel serverChannel : serverChannels) - serverChannel.close(); - - // now close the send channels - for (OutboundMessagingPool pool : channelManagers.values()) - pool.close(false); - - if (!isTest) - NettyFactory.instance.close(); - - clearMessageSinks(); - } - catch (Exception e) - { - throw new IOError(e); - } - } - - /** - * For testing only! - */ - void clearServerChannels() - { - serverChannels.clear(); - } - - public void receive(MessageIn message, int id) - { - TraceState state = Tracing.instance.initializeFromMessage(message); - if (state != null) - state.trace("{} message received from {}", message.verb, message.from); - - // message sinks are a testing hook - for (IMessageSink ms : messageSinks) - if (!ms.allowIncomingMessage(message, id)) - return; - - Runnable runnable = new MessageDeliveryTask(message, id); - LocalAwareExecutorService stage = StageManager.getStage(message.getMessageType()); - assert stage != null : "No stage for message type " + message.verb; - - stage.execute(runnable, ExecutorLocals.create(state)); - } - - public void setCallbackForTests(int messageId, CallbackInfo callback) - { - callbacks.put(messageId, callback); - } - - public CallbackInfo getRegisteredCallback(int messageId) - { - return callbacks.get(messageId); - } - - public CallbackInfo removeRegisteredCallback(int messageId) - { - 
return callbacks.remove(messageId); - } - - /** - * @return System.nanoTime() when callback was created. - */ - public long getRegisteredCallbackAge(int messageId) - { - return callbacks.getAge(messageId); - } - - public static void validateMagic(int magic) throws IOException - { - if (magic != PROTOCOL_MAGIC) - throw new IOException("invalid protocol header"); - } - - public static int getBits(int packed, int start, int count) + private OutboundConnections getOutbound(InetAddressAndPort to) { - return packed >>> (start + 1) - count & ~(-1 << count); + OutboundConnections connections = channelManagers.get(to); + if (connections == null) + connections = OutboundConnections.tryRegister(channelManagers, to, new OutboundConnectionSettings(to).withDefaults(ConnectionCategory.MESSAGING), backPressure.newState(to)); + return connections; } - /** - * @return the last version associated with address, or @param version if this is the first such version - */ - public int setVersion(InetAddressAndPort endpoint, int version) - { - logger.trace("Setting version {} for {}", version, endpoint); - - Integer v = versions.put(endpoint, version); - return v == null ? version : v; - } - - public void resetVersion(InetAddressAndPort endpoint) - { - logger.trace("Resetting version for {}", endpoint); - versions.remove(endpoint); - } - - /** - * Returns the messaging-version as announced by the given node but capped - * to the min of the version as announced by the node and {@link #current_version}. - */ - public int getVersion(InetAddressAndPort endpoint) - { - Integer v = versions.get(endpoint); - if (v == null) - { - // we don't know the version. assume current. we'll know soon enough if that was incorrect. 
- logger.trace("Assuming current protocol version for {}", endpoint); - return MessagingService.current_version; - } - else - return Math.min(v, MessagingService.current_version); - } - - public int getVersion(String endpoint) throws UnknownHostException - { - return getVersion(InetAddressAndPort.getByName(endpoint)); - } - - /** - * Returns the messaging-version exactly as announced by the given endpoint. - */ - public int getRawVersion(InetAddressAndPort endpoint) - { - Integer v = versions.get(endpoint); - if (v == null) - throw new IllegalStateException("getRawVersion() was called without checking knowsVersion() result first"); - return v; - } - - public boolean knowsVersion(InetAddressAndPort endpoint) - { - return versions.containsKey(endpoint); - } - - public void incrementDroppedMutations(Optional mutationOpt, long timeTaken) - { - if (mutationOpt.isPresent()) - { - updateDroppedMutationCount(mutationOpt.get()); - } - incrementDroppedMessages(Verb.MUTATION, timeTaken); - } - - public void incrementDroppedMessages(Verb verb) - { - incrementDroppedMessages(verb, false); - } - - public void incrementDroppedMessages(Verb verb, long timeTaken) - { - incrementDroppedMessages(verb, timeTaken, false); - } - - public void incrementDroppedMessages(MessageIn message, long timeTaken) - { - if (message.payload instanceof IMutation) - { - updateDroppedMutationCount((IMutation) message.payload); - } - incrementDroppedMessages(message.verb, timeTaken, message.isCrossNode()); - } - - public void incrementDroppedMessages(Verb verb, long timeTaken, boolean isCrossNode) - { - assert DROPPABLE_VERBS.contains(verb) : "Verb " + verb + " should not legally be dropped"; - incrementDroppedMessages(droppedMessagesMap.get(verb), timeTaken, isCrossNode); - } - - public void incrementDroppedMessages(Verb verb, boolean isCrossNode) - { - assert DROPPABLE_VERBS.contains(verb) : "Verb " + verb + " should not legally be dropped"; - incrementDroppedMessages(droppedMessagesMap.get(verb), 
isCrossNode); - } - - private void updateDroppedMutationCount(IMutation mutation) - { - assert mutation != null : "Mutation should not be null when updating dropped mutations count"; - - for (TableId tableId : mutation.getTableIds()) - { - ColumnFamilyStore cfs = Keyspace.open(mutation.getKeyspaceName()).getColumnFamilyStore(tableId); - if (cfs != null) - { - cfs.metric.droppedMutations.inc(); - } - } - } - - private void incrementDroppedMessages(DroppedMessages droppedMessages, long timeTaken, boolean isCrossNode) - { - if (isCrossNode) - droppedMessages.metrics.crossNodeDroppedLatency.update(timeTaken, TimeUnit.MILLISECONDS); - else - droppedMessages.metrics.internalDroppedLatency.update(timeTaken, TimeUnit.MILLISECONDS); - incrementDroppedMessages(droppedMessages, isCrossNode); - } - - private void incrementDroppedMessages(DroppedMessages droppedMessages, boolean isCrossNode) - { - droppedMessages.metrics.dropped.mark(); - if (isCrossNode) - droppedMessages.droppedCrossNode.incrementAndGet(); - else - droppedMessages.droppedInternal.incrementAndGet(); - } - - private void logDroppedMessages() - { - List logs = getDroppedMessagesLogs(); - for (String log : logs) - logger.info(log); - - if (logs.size() > 0) - StatusLogger.log(); - } - - @VisibleForTesting - List getDroppedMessagesLogs() - { - List ret = new ArrayList<>(); - for (Map.Entry entry : droppedMessagesMap.entrySet()) - { - Verb verb = entry.getKey(); - DroppedMessages droppedMessages = entry.getValue(); - - int droppedInternal = droppedMessages.droppedInternal.getAndSet(0); - int droppedCrossNode = droppedMessages.droppedCrossNode.getAndSet(0); - if (droppedInternal > 0 || droppedCrossNode > 0) - { - ret.add(String.format("%s messages were dropped in last %d ms: %d internal and %d cross node." 
- + " Mean internal dropped latency: %d ms and Mean cross-node dropped latency: %d ms", - verb, - LOG_DROPPED_INTERVAL_IN_MS, - droppedInternal, - droppedCrossNode, - TimeUnit.NANOSECONDS.toMillis((long)droppedMessages.metrics.internalDroppedLatency.getSnapshot().getMean()), - TimeUnit.NANOSECONDS.toMillis((long)droppedMessages.metrics.crossNodeDroppedLatency.getSnapshot().getMean()))); - } - } - return ret; - } - - - private static void handleIOExceptionOnClose(IOException e) throws IOException - { - // dirty hack for clean shutdown on OSX w/ Java >= 1.8.0_20 - // see https://bugs.openjdk.java.net/browse/JDK-8050499; - // also CASSANDRA-12513 - if (NativeLibrary.osType == NativeLibrary.OSType.MAC) - { - switch (e.getMessage()) - { - case "Unknown error: 316": - case "No such file or directory": - return; - } - } - - throw e; - } - - public Map getLargeMessagePendingTasks() - { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(false), entry.getValue().largeMessageChannel.getPendingMessages()); - return pendingTasks; - } - - public Map getLargeMessageCompletedTasks() - { - Map completedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(false), entry.getValue().largeMessageChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getLargeMessageDroppedTasks() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(false), entry.getValue().largeMessageChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getSmallMessagePendingTasks() - { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(false), 
entry.getValue().smallMessageChannel.getPendingMessages()); - return pendingTasks; - } - - public Map getSmallMessageCompletedTasks() - { - Map completedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(false), entry.getValue().smallMessageChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getSmallMessageDroppedTasks() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(false), entry.getValue().smallMessageChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getGossipMessagePendingTasks() - { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(false), entry.getValue().gossipChannel.getPendingMessages()); - return pendingTasks; - } - - public Map getGossipMessageCompletedTasks() - { - Map completedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(false), entry.getValue().gossipChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getGossipMessageDroppedTasks() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(false), entry.getValue().gossipChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getLargeMessagePendingTasksWithPort() + InboundMessageHandlers getInbound(InetAddressAndPort from) { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(), entry.getValue().largeMessageChannel.getPendingMessages()); - return pendingTasks; - } + InboundMessageHandlers handlers = 
messageHandlers.get(from); + if (null != handlers) + return handlers; - public Map getLargeMessageCompletedTasksWithPort() - { - Map completedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(), entry.getValue().largeMessageChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getLargeMessageDroppedTasksWithPort() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(), entry.getValue().largeMessageChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getSmallMessagePendingTasksWithPort() - { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(), entry.getValue().smallMessageChannel.getPendingMessages()); - return pendingTasks; - } - - public Map getSmallMessageCompletedTasksWithPort() - { - Map completedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(), entry.getValue().smallMessageChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getSmallMessageDroppedTasksWithPort() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(), entry.getValue().smallMessageChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getGossipMessagePendingTasksWithPort() - { - Map pendingTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - pendingTasks.put(entry.getKey().toString(), entry.getValue().gossipChannel.getPendingMessages()); - return pendingTasks; - } - - public Map getGossipMessageCompletedTasksWithPort() - { - Map completedTasks = new 
HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - completedTasks.put(entry.getKey().toString(), entry.getValue().gossipChannel.getCompletedMessages()); - return completedTasks; - } - - public Map getGossipMessageDroppedTasksWithPort() - { - Map droppedTasks = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - droppedTasks.put(entry.getKey().toString(), entry.getValue().gossipChannel.getDroppedMessages()); - return droppedTasks; - } - - public Map getDroppedMessages() - { - Map map = new HashMap<>(droppedMessagesMap.size()); - for (Map.Entry entry : droppedMessagesMap.entrySet()) - map.put(entry.getKey().toString(), (int) entry.getValue().metrics.dropped.getCount()); - return map; - } - - public long getTotalTimeouts() - { - return ConnectionMetrics.totalTimeouts.getCount(); - } - - public Map getTimeoutsPerHost() - { - Map result = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - { - String ip = entry.getKey().toString(false); - long recent = entry.getValue().getTimeouts(); - result.put(ip, recent); - } - return result; - } - - public Map getTimeoutsPerHostWithPort() - { - Map result = new HashMap(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - { - String ip = entry.getKey().toString(); - long recent = entry.getValue().getTimeouts(); - result.put(ip, recent); - } - return result; - } - - public Map getBackPressurePerHost() - { - Map map = new HashMap<>(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - map.put(entry.getKey().toString(false), entry.getValue().getBackPressureState().getBackPressureRateLimit()); - - return map; - } - - public Map getBackPressurePerHostWithPort() - { - Map map = new HashMap<>(channelManagers.size()); - for (Map.Entry entry : channelManagers.entrySet()) - map.put(entry.getKey().toString(false), 
entry.getValue().getBackPressureState().getBackPressureRateLimit()); - - return map; - } - - @Override - public void setBackPressureEnabled(boolean enabled) - { - DatabaseDescriptor.setBackPressureEnabled(enabled); - } - - @Override - public boolean isBackPressureEnabled() - { - return DatabaseDescriptor.backPressureEnabled(); - } - - public static IPartitioner globalPartitioner() - { - return StorageService.instance.getTokenMetadata().partitioner; - } - - public static void validatePartitioner(Collection> allBounds) - { - for (AbstractBounds bounds : allBounds) - validatePartitioner(bounds); - } - - public static void validatePartitioner(AbstractBounds bounds) - { - if (globalPartitioner() != bounds.left.getPartitioner()) - throw new AssertionError(String.format("Partitioner in bounds serialization. Expected %s, was %s.", - globalPartitioner().getClass().getName(), - bounds.left.getPartitioner().getClass().getName())); - } - - /** - * This method is used to determine the preferred IP & Port of a peer using the - * {@link OutboundMessagingPool} and SystemKeyspace. - */ - public InetAddressAndPort getPreferredRemoteAddr(InetAddressAndPort to) - { - OutboundMessagingPool pool = channelManagers.get(to); - return pool != null ? pool.getPreferredRemoteAddr() : SystemKeyspace.getPreferredIP(to); - } - - private OutboundMessagingPool getMessagingConnection(InetAddressAndPort to) - { - OutboundMessagingPool pool = channelManagers.get(to); - if (pool == null) - { - final boolean secure = isEncryptedConnection(to); - final int port = portFor(to, secure); - if (!DatabaseDescriptor.getInternodeAuthenticator().authenticate(to.address, port)) - return null; - - InetAddressAndPort preferredRemote = SystemKeyspace.getPreferredIP(to); - InetAddressAndPort local = FBUtilities.getBroadcastAddressAndPort(); - ServerEncryptionOptions encryptionOptions = secure ? 
DatabaseDescriptor.getInternodeMessagingEncyptionOptions() : null; - IInternodeAuthenticator authenticator = DatabaseDescriptor.getInternodeAuthenticator(); - - pool = new OutboundMessagingPool(preferredRemote, local, encryptionOptions, backPressure.newState(to), authenticator); - OutboundMessagingPool existing = channelManagers.putIfAbsent(to, pool); - if (existing != null) - { - pool.close(false); - pool = existing; - } - } - return pool; - } - - public int portFor(InetAddressAndPort addr) - { - final boolean secure = isEncryptedConnection(addr); - return portFor(addr, secure); - } - - private int portFor(InetAddressAndPort address, boolean secure) - { - if (!secure) - return address.port; - - Integer v = versions.get(address); - // if we don't know the version of the peer, assume it is 4.0 (or higher) as the only time is would be lower - // (as in a 3.x version) is during a cluster upgrade (from 3.x to 4.0). In that case the outbound connection will - // unfortunately fail - however the peer should connect to this node (at some point), and once we learn it's version, it'll be - // in versions map. thus, when we attempt to reconnect to that node, we'll have the version and we can get the correct port. - // we will be able to remove this logic at 5.0. - // Also as of 4.0 we will propagate the "regular" port (which will support both SSL and non-SSL) via gossip so - // for SSL and version 4.0 always connect to the gossiped port because if SSL is enabled it should ALWAYS - // listen for SSL on the "regular" port. - int version = v != null ? v.intValue() : VERSION_40; - return version < VERSION_40 ? 
DatabaseDescriptor.getSSLStoragePort() : address.port; + return messageHandlers.computeIfAbsent(from, addr -> + new InboundMessageHandlers(FBUtilities.getLocalAddressAndPort(), + addr, + DatabaseDescriptor.getInternodeApplicationReceiveQueueCapacityInBytes(), + DatabaseDescriptor.getInternodeApplicationReceiveQueueReserveEndpointCapacityInBytes(), + inboundGlobalReserveLimits, metrics, inboundSink) + ); } @VisibleForTesting - boolean isConnected(InetAddressAndPort address, MessageOut messageOut) + boolean isConnected(InetAddressAndPort address, Message messageOut) { - OutboundMessagingPool pool = channelManagers.get(address); + OutboundConnections pool = channelManagers.get(address); if (pool == null) return false; - return pool.getConnection(messageOut).isConnected(); + return pool.connectionFor(messageOut).isConnected(); } - public static boolean isEncryptedConnection(InetAddressAndPort address) + public void listen() { - IEndpointSnitch snitch = DatabaseDescriptor.getEndpointSnitch(); - switch (DatabaseDescriptor.getInternodeMessagingEncyptionOptions().internode_encryption) - { - case none: - return false; // if nothing needs to be encrypted then return immediately. - case all: - break; - case dc: - if (snitch.getDatacenter(address).equals(snitch.getLocalDatacenter())) - return false; - break; - case rack: - // for rack then check if the DC's are the same. 
- if (snitch.getRack(address).equals(snitch.getLocalRack()) - && snitch.getDatacenter(address).equals(snitch.getLocalDatacenter())) - return false; - break; - } - return true; + inboundSockets.open(); } - @Override - public void reloadSslCertificates() throws IOException + public void waitUntilListening() throws InterruptedException { - final ServerEncryptionOptions serverOpts = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); - final EncryptionOptions clientOpts = DatabaseDescriptor.getNativeProtocolEncryptionOptions(); - SSLFactory.validateSslCerts(serverOpts, clientOpts); - SSLFactory.checkCertFilesForHotReloading(serverOpts, clientOpts); + inboundSockets.open().await(); } } diff --git a/src/java/org/apache/cassandra/net/MessagingServiceMBeanImpl.java b/src/java/org/apache/cassandra/net/MessagingServiceMBeanImpl.java new file mode 100644 index 000000000000..b48ae1c625e1 --- /dev/null +++ b/src/java/org/apache/cassandra/net/MessagingServiceMBeanImpl.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.metrics.InternodeOutboundMetrics; +import org.apache.cassandra.metrics.MessagingMetrics; +import org.apache.cassandra.security.SSLFactory; +import org.apache.cassandra.utils.MBeanWrapper; + +public class MessagingServiceMBeanImpl implements MessagingServiceMBean +{ + public static final String MBEAN_NAME = "org.apache.cassandra.net:type=MessagingService"; + + // we use CHM deliberately instead of NBHM, as both are non-blocking for readers (which this map mostly is used for) + // and CHM permits prompter GC + public final ConcurrentMap channelManagers = new ConcurrentHashMap<>(); + public final ConcurrentMap messageHandlers = new ConcurrentHashMap<>(); + + public final EndpointMessagingVersions versions = new EndpointMessagingVersions(); + public final MessagingMetrics metrics = new MessagingMetrics(); + + MessagingServiceMBeanImpl(boolean testOnly) + { + if (!testOnly) + { + MBeanWrapper.instance.registerMBean(this, MBEAN_NAME); + metrics.scheduleLogging(); + } + } + + @Override + public Map getLargeMessagePendingTasks() + { + Map pendingTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(false), entry.getValue().large.pendingCount()); + return pendingTasks; + } + + @Override + public Map getLargeMessageCompletedTasks() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(false), entry.getValue().large.sentCount()); + return 
completedTasks; + } + + @Override + public Map getLargeMessageDroppedTasks() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + droppedTasks.put(entry.getKey().toString(false), entry.getValue().large.dropped()); + return droppedTasks; + } + + @Override + public Map getSmallMessagePendingTasks() + { + Map pendingTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(false), entry.getValue().small.pendingCount()); + return pendingTasks; + } + + @Override + public Map getSmallMessageCompletedTasks() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(false), entry.getValue().small.sentCount()); + return completedTasks; + } + + @Override + public Map getSmallMessageDroppedTasks() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + droppedTasks.put(entry.getKey().toString(false), entry.getValue().small.dropped()); + return droppedTasks; + } + + @Override + public Map getGossipMessagePendingTasks() + { + Map pendingTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(false), entry.getValue().urgent.pendingCount()); + return pendingTasks; + } + + @Override + public Map getGossipMessageCompletedTasks() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(false), entry.getValue().urgent.sentCount()); + return completedTasks; + } + + @Override + public Map getGossipMessageDroppedTasks() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + 
droppedTasks.put(entry.getKey().toString(false), entry.getValue().urgent.dropped()); + return droppedTasks; + } + + @Override + public Map getLargeMessagePendingTasksWithPort() + { + Map pendingTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(), entry.getValue().large.pendingCount()); + return pendingTasks; + } + + @Override + public Map getLargeMessageCompletedTasksWithPort() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(), entry.getValue().large.sentCount()); + return completedTasks; + } + + @Override + public Map getLargeMessageDroppedTasksWithPort() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + droppedTasks.put(entry.getKey().toString(), entry.getValue().large.dropped()); + return droppedTasks; + } + + @Override + public Map getSmallMessagePendingTasksWithPort() + { + Map pendingTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(), entry.getValue().small.pendingCount()); + return pendingTasks; + } + + @Override + public Map getSmallMessageCompletedTasksWithPort() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(), entry.getValue().small.sentCount()); + return completedTasks; + } + + @Override + public Map getSmallMessageDroppedTasksWithPort() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + droppedTasks.put(entry.getKey().toString(), entry.getValue().small.dropped()); + return droppedTasks; + } + + @Override + public Map getGossipMessagePendingTasksWithPort() + { + Map pendingTasks = new 
HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + pendingTasks.put(entry.getKey().toString(), entry.getValue().urgent.pendingCount()); + return pendingTasks; + } + + @Override + public Map getGossipMessageCompletedTasksWithPort() + { + Map completedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + completedTasks.put(entry.getKey().toString(), entry.getValue().urgent.sentCount()); + return completedTasks; + } + + @Override + public Map getGossipMessageDroppedTasksWithPort() + { + Map droppedTasks = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + droppedTasks.put(entry.getKey().toString(), entry.getValue().urgent.dropped()); + return droppedTasks; + } + + @Override + public Map getDroppedMessages() + { + return metrics.getDroppedMessages(); + } + + @Override + public long getTotalTimeouts() + { + return InternodeOutboundMetrics.totalExpiredCallbacks.getCount(); + } + + // these are not messages that time out on sending, but callbacks that timedout without receiving a response + @Override + public Map getTimeoutsPerHost() + { + Map result = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + { + String ip = entry.getKey().toString(false); + long recent = entry.getValue().expiredCallbacks(); + result.put(ip, recent); + } + return result; + } + + // these are not messages that time out on sending, but callbacks that timedout without receiving a response + @Override + public Map getTimeoutsPerHostWithPort() + { + Map result = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + { + String ip = entry.getKey().toString(); + long recent = entry.getValue().expiredCallbacks(); + result.put(ip, recent); + } + return result; + } + + @Override + public Map getBackPressurePerHost() + { + Map map = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : 
channelManagers.entrySet()) + map.put(entry.getKey().toString(false), entry.getValue().getBackPressureState().getBackPressureRateLimit()); + + return map; + } + + @Override + public Map getBackPressurePerHostWithPort() + { + Map map = new HashMap<>(channelManagers.size()); + for (Map.Entry entry : channelManagers.entrySet()) + map.put(entry.getKey().toString(false), entry.getValue().getBackPressureState().getBackPressureRateLimit()); + + return map; + } + + @Override + public void setBackPressureEnabled(boolean enabled) + { + DatabaseDescriptor.setBackPressureEnabled(enabled); + } + + @Override + public boolean isBackPressureEnabled() + { + return DatabaseDescriptor.backPressureEnabled(); + } + + @Override + public void reloadSslCertificates() throws IOException + { + final EncryptionOptions.ServerEncryptionOptions serverOpts = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); + final EncryptionOptions clientOpts = DatabaseDescriptor.getNativeProtocolEncryptionOptions(); + SSLFactory.validateSslCerts(serverOpts, clientOpts); + SSLFactory.checkCertFilesForHotReloading(serverOpts, clientOpts); + } + + @Override + public int getVersion(String address) throws UnknownHostException + { + return versions.get(address); + } +} diff --git a/src/java/org/apache/cassandra/net/PongMessage.java b/src/java/org/apache/cassandra/net/NoPayload.java similarity index 59% rename from src/java/org/apache/cassandra/net/PongMessage.java rename to src/java/org/apache/cassandra/net/NoPayload.java index bb89cdf51368..3b2b1772a8cc 100644 --- a/src/java/org/apache/cassandra/net/PongMessage.java +++ b/src/java/org/apache/cassandra/net/NoPayload.java @@ -15,36 +15,39 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.cassandra.net; -import java.io.IOException; - import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; -public class PongMessage +/** + * Empty message payload - primarily used for responses. + * + * Prefer this singleton to writing one-off specialised classes. + */ +public class NoPayload { - public static final PongMessage instance = new PongMessage(); - public static IVersionedSerializer serializer = new PongMessage.PongMessageSerializer(); + public static final NoPayload noPayload = new NoPayload(); - private PongMessage() - { } + private NoPayload() {} - public static class PongMessageSerializer implements IVersionedSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { - public void serialize(PongMessage t, DataOutputPlus out, int version) throws IOException - { } + public void serialize(NoPayload noPayload, DataOutputPlus out, int version) + { + if (noPayload != NoPayload.noPayload) + throw new IllegalArgumentException(); + } - public PongMessage deserialize(DataInputPlus in, int version) throws IOException + public NoPayload deserialize(DataInputPlus in, int version) { - return instance; + return noPayload; } - public long serializedSize(PongMessage t, int version) + public long serializedSize(NoPayload noPayload, int version) { return 0; } - } + }; } diff --git a/src/java/org/apache/cassandra/net/NoSizeEstimator.java b/src/java/org/apache/cassandra/net/NoSizeEstimator.java new file mode 100644 index 000000000000..848d4f566269 --- /dev/null +++ b/src/java/org/apache/cassandra/net/NoSizeEstimator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import io.netty.channel.MessageSizeEstimator; + +/** + * We want to manage the bytes we have in-flight, so this class asks Netty not to by returning zero for every object. + */ +class NoSizeEstimator implements MessageSizeEstimator, MessageSizeEstimator.Handle +{ + public static final NoSizeEstimator instance = new NoSizeEstimator(); + private NoSizeEstimator() {} + public Handle newHandle() { return this; } + public int size(Object o) { return 0; } +} diff --git a/src/java/org/apache/cassandra/net/OutboundConnection.java b/src/java/org/apache/cassandra/net/OutboundConnection.java new file mode 100644 index 000000000000..63b909c58a13 --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundConnection.java @@ -0,0 +1,1729 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.EventLoop; +import io.netty.channel.unix.Errors; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.concurrent.SucceededFuture; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.util.DataOutputBufferFixed; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result.MessagingSuccess; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.FBUtilities; +import 
org.apache.cassandra.utils.JVMStabilityInspector; +import org.apache.cassandra.utils.NoSpamLogger; + +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.OutboundConnectionInitiator.*; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; +import static org.apache.cassandra.net.ResourceLimits.*; +import static org.apache.cassandra.net.ResourceLimits.Outcome.*; +import static org.apache.cassandra.net.SocketFactory.*; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; +import static org.apache.cassandra.utils.Throwables.isCausedBy; + +/** + * Represents a connection type to a peer, and handles the state transistions on the connection and the netty {@link Channel}. + * The underlying socket is not opened until explicitly requested (by sending a message). + * + * TODO: complete this description + * + * Aside from a few administrative methods, the main entry point to sending a message is {@link #enqueue(Message)}. + * Any thread may send a message (enqueueing it to {@link #queue}), but only one thread may consume messages from this + * queue. There is a single delivery thread - either the event loop, or a companion thread - that has logical ownership + * of the queue, but other threads may temporarily take ownership in order to perform book keeping, pruning, etc., + * to ensure system stability. + * + * {@link Delivery#run()} is the main entry point for consuming messages from the queue, and executes either on the event + * loop or on a non-dedicated companion thread. This processing is activated via {@link Delivery#execute()}. + * + * Almost all internal state maintenance on this class occurs on the eventLoop, a single threaded executor which is + * assigned in the constructor. Further details are outlined below in the class. 
Some behaviours require coordination + * between the eventLoop and the companion thread (if any). Some minimal set of behaviours are permitted to occur on + * producers to ensure the connection remains healthy and does not overcommit resources. + * + * All methods are safe to invoke from any thread unless otherwise stated. + */ +@SuppressWarnings({ "WeakerAccess", "FieldMayBeFinal", "NonAtomicOperationOnVolatileField", "SameParameterValue" }) +public class OutboundConnection +{ + static final Logger logger = LoggerFactory.getLogger(OutboundConnection.class); + private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(logger, 30L, TimeUnit.SECONDS); + + private static final AtomicLongFieldUpdater submittedUpdater = AtomicLongFieldUpdater.newUpdater(OutboundConnection.class, "submittedCount"); + private static final AtomicLongFieldUpdater pendingCountAndBytesUpdater = AtomicLongFieldUpdater.newUpdater(OutboundConnection.class, "pendingCountAndBytes"); + private static final AtomicLongFieldUpdater overloadedCountUpdater = AtomicLongFieldUpdater.newUpdater(OutboundConnection.class, "overloadedCount"); + private static final AtomicLongFieldUpdater overloadedBytesUpdater = AtomicLongFieldUpdater.newUpdater(OutboundConnection.class, "overloadedBytes"); + private static final AtomicReferenceFieldUpdater closingUpdater = AtomicReferenceFieldUpdater.newUpdater(OutboundConnection.class, Future.class, "closing"); + private static final AtomicReferenceFieldUpdater scheduledCloseUpdater = AtomicReferenceFieldUpdater.newUpdater(OutboundConnection.class, Future.class, "scheduledClose"); + + private final EventLoop eventLoop; + private final Delivery delivery; + + private final OutboundMessageCallbacks callbacks; + private final OutboundDebugCallbacks debug; + private final OutboundMessageQueue queue; + /** the number of bytes we permit to queue to the network without acquiring any shared resource permits */ + private final long pendingCapacityInBytes; + /** the 
number of messages and bytes queued for flush to the network, + * including those that are being flushed but have not been completed, + * packed into a long (top 20 bits for count, bottom 42 for bytes)*/ + private volatile long pendingCountAndBytes = 0; + /** global shared limits that we use only if our local limits are exhausted; + * we allocate from here whenever queueSize > queueCapacity */ + private final EndpointAndGlobal reserveCapacityInBytes; + + private volatile long submittedCount = 0; // updated with cas + private volatile long overloadedCount = 0; // updated with cas + private volatile long overloadedBytes = 0; // updated with cas + private long expiredCount = 0; // updated with queue lock held + private long expiredBytes = 0; // updated with queue lock held + private long errorCount = 0; // updated only by delivery thread + private long errorBytes = 0; // updated by delivery thread only + private long sentCount; // updated by delivery thread only + private long sentBytes; // updated by delivery thread only + private long successfulConnections; // updated by event loop only + private long connectionAttempts; // updated by event loop only + + private static final int pendingByteBits = 42; + private static boolean isMaxPendingCount(long pendingCountAndBytes) + { + return (pendingCountAndBytes & (-1L << pendingByteBits)) == (-1L << pendingByteBits); + } + + private static int pendingCount(long pendingCountAndBytes) + { + return (int) (pendingCountAndBytes >>> pendingByteBits); + } + + private static long pendingBytes(long pendingCountAndBytes) + { + return pendingCountAndBytes & (-1L >>> (64 - pendingByteBits)); + } + + private static long pendingCountAndBytes(long pendingCount, long pendingBytes) + { + return (pendingCount << pendingByteBits) | pendingBytes; + } + + private final ConnectionType type; + + /** + * Contains the base settings for this connection, _including_ any defaults filled in. 
+ * + */ + private OutboundConnectionSettings template; + + private static class State + { + static final State CLOSED = new State(Kind.CLOSED); + + enum Kind { ESTABLISHED, CONNECTING, DORMANT, CLOSED } + + final Kind kind; + + State(Kind kind) + { + this.kind = kind; + } + + boolean isEstablished() { return kind == Kind.ESTABLISHED; } + boolean isConnecting() { return kind == Kind.CONNECTING; } + boolean isDisconnected() { return kind == Kind.CONNECTING || kind == Kind.DORMANT; } + boolean isClosed() { return kind == Kind.CLOSED; } + + Established established() { return (Established) this; } + Connecting connecting() { return (Connecting) this; } + Disconnected disconnected() { return (Disconnected) this; } + } + + /** + * We have successfully negotiated a channel, and believe it to still be valid. + * + * Before using this, we should check isConnected() to check the Channel hasn't + * become invalid. + */ + private static class Established extends State + { + final int messagingVersion; + final Channel channel; + final FrameEncoder.PayloadAllocator payloadAllocator; + final OutboundConnectionSettings settings; + + Established(int messagingVersion, Channel channel, FrameEncoder.PayloadAllocator payloadAllocator, OutboundConnectionSettings settings) + { + super(Kind.ESTABLISHED); + this.messagingVersion = messagingVersion; + this.channel = channel; + this.payloadAllocator = payloadAllocator; + this.settings = settings; + } + + boolean isConnected() { return channel.isOpen(); } + } + + private static class Disconnected extends State + { + /** Periodic message expiry scheduled while we are disconnected; this will be cancelled and cleared each time we connect */ + final Future maintenance; + Disconnected(Kind kind, Future maintenance) + { + super(kind); + this.maintenance = maintenance; + } + + public static Disconnected dormant(Future maintenance) + { + return new Disconnected(Kind.DORMANT, maintenance); + } + } + + private static class Connecting extends 
Disconnected + { + /** + * Currently (or scheduled to) (re)connect; this may be cancelled (if closing) or waited on (for delivery) + * + * - The work managed by this future is partially performed asynchronously, not necessarily on the eventLoop. + * - It is only completed on the eventLoop + * - It may not be executing, but might be scheduled to be submitted if {@link #scheduled} is not null + */ + final Future> attempt; + + /** + * If we are retrying to connect with some delay, this represents the scheduled inititation of another attempt + */ + @Nullable + final Future scheduled; + + /** + * true iff we are retrying to connect after some failure (immediately or following a delay) + */ + final boolean isFailingToConnect; + + Connecting(Disconnected previous, Future> attempt) + { + this(previous, attempt, null); + } + + Connecting(Disconnected previous, Future> attempt, Future scheduled) + { + super(Kind.CONNECTING, previous.maintenance); + this.attempt = attempt; + this.scheduled = scheduled; + this.isFailingToConnect = scheduled != null || (previous.isConnecting() && previous.connecting().isFailingToConnect); + } + + /** + * Cancel the connection attempt + * + * No cleanup is needed here, as {@link #attempt} is only completed on the eventLoop, + * so we have either already invoked the callbacks and are no longer in {@link #state}, + * or the {@link OutboundConnectionInitiator} will handle our successful cancellation + * when it comes to complete, by closing the channel (if we could not cancel it before then) + */ + void cancel() + { + if (scheduled != null) + scheduled.cancel(true); + + // we guarantee that attempt is only ever completed by the eventLoop + boolean cancelled = attempt.cancel(true); + assert cancelled; + } + } + + private volatile State state; + + /** The connection is being permanently closed */ + private volatile Future closing; + /** The connection is being permanently closed in the near future */ + private volatile Future scheduledClose; + + 
OutboundConnection(ConnectionType type, OutboundConnectionSettings settings, EndpointAndGlobal reserveCapacityInBytes) + { + this.template = settings.withDefaults(ConnectionCategory.MESSAGING); + this.type = type; + this.eventLoop = template.socketFactory.defaultGroup().next(); + this.pendingCapacityInBytes = template.applicationSendQueueCapacityInBytes; + this.reserveCapacityInBytes = reserveCapacityInBytes; + this.callbacks = template.callbacks; + this.debug = template.debug; + this.queue = new OutboundMessageQueue(this::onExpired); + this.delivery = type == ConnectionType.LARGE_MESSAGES + ? new LargeMessageDelivery(template.socketFactory.synchronousWorkExecutor) + : new EventLoopDelivery(); + setDisconnected(); + } + + /** + * This is the main entry point for enqueuing a message to be sent to the remote peer. + */ + public void enqueue(Message message) throws ClosedChannelException + { + if (isClosing()) + throw new ClosedChannelException(); + + final int canonicalSize = canonicalSize(message); + if (canonicalSize > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes()) + throw new Message.OversizedMessageException(canonicalSize); + + submittedUpdater.incrementAndGet(this); + switch (acquireCapacity(canonicalSize)) + { + case INSUFFICIENT_ENDPOINT: + // if we're overloaded to one endpoint, we may be accumulating expirable messages, so + // attempt an expiry to see if this makes room for our newer message. 
+ // this is an optimisation only; messages will be expired on ~100ms cycle, and by Delivery when it runs + if (queue.maybePruneExpired() && SUCCESS == acquireCapacity(canonicalSize)) + break; + case INSUFFICIENT_GLOBAL: + onOverloaded(message); + return; + } + + queue.add(message); + delivery.execute(); + + // we might race with the channel closing; if this happens, to ensure this message eventually arrives + // we need to remove ourselves from the queue and throw a ClosedChannelException, so that another channel + // can be opened in our place to try and send on. + if (isClosing() && queue.remove(message)) + { + releaseCapacity(1, canonicalSize); + throw new ClosedChannelException(); + } + } + + /** + * Try to acquire the necessary resource permits for a number of pending bytes for this connection. + * + * Since the owner limit is shared amongst multiple connections, our semantics cannot be super trivial. + * Were they per-connection, we could simply perform an atomic increment of the queue size, then + * allocate any excess we need in the reserve, and on release free everything we see from both. + * Since we are coordinating two independent atomic variables we have to track every byte we allocate in reserve + * and ensure it is matched by a corresponding released byte. We also need to be sure we do not permit another + * releasing thread to release reserve bytes we have not yet - and may never - actually reserve. + * + * As such, we have to first check if we would need reserve bytes, then allocate them *before* we increment our + * queue size. We only increment the queue size if the reserve bytes are definitely not needed, or we could first + * obtain them. If in the process of obtaining any reserve bytes the queue size changes, we have some bytes that are + * reserved for us, but may be a different number to that we need. So we must continue to track these. 
+ * + * In the happy path, this is still efficient as we simply CAS + */ + private Outcome acquireCapacity(long bytes) + { + return acquireCapacity(1, bytes); + } + + private Outcome acquireCapacity(long count, long bytes) + { + long increment = pendingCountAndBytes(count, bytes); + long unusedClaimedReserve = 0; + Outcome outcome = null; + loop: while (true) + { + long current = pendingCountAndBytes; + if (isMaxPendingCount(current)) + { + outcome = INSUFFICIENT_ENDPOINT; + break; + } + + long next = current + increment; + if (pendingBytes(next) <= pendingCapacityInBytes) + { + if (pendingCountAndBytesUpdater.compareAndSet(this, current, next)) + { + outcome = SUCCESS; + break; + } + continue; + } + + State state = this.state; + if (state.isConnecting() && state.connecting().isFailingToConnect) + { + outcome = INSUFFICIENT_ENDPOINT; + break; + } + + long requiredReserve = min(bytes, pendingBytes(next) - pendingCapacityInBytes); + if (unusedClaimedReserve < requiredReserve) + { + long extraGlobalReserve = requiredReserve - unusedClaimedReserve; + switch (outcome = reserveCapacityInBytes.tryAllocate(extraGlobalReserve)) + { + case INSUFFICIENT_ENDPOINT: + case INSUFFICIENT_GLOBAL: + break loop; + case SUCCESS: + unusedClaimedReserve += extraGlobalReserve; + } + } + + if (pendingCountAndBytesUpdater.compareAndSet(this, current, next)) + { + unusedClaimedReserve -= requiredReserve; + break; + } + } + + if (unusedClaimedReserve > 0) + reserveCapacityInBytes.release(unusedClaimedReserve); + + return outcome; + } + + /** + * Mark a number of pending bytes as flushed to the network, releasing their capacity for new outbound messages. 
+ */ + private void releaseCapacity(long count, long bytes) + { + long decrement = pendingCountAndBytes(count, bytes); + long prev = pendingCountAndBytesUpdater.getAndAdd(this, -decrement); + if (pendingBytes(prev) > pendingCapacityInBytes) + { + long excess = min(pendingBytes(prev) - pendingCapacityInBytes, bytes); + reserveCapacityInBytes.release(excess); + } + } + + private void onOverloaded(Message message) + { + overloadedCountUpdater.incrementAndGet(this); + overloadedBytesUpdater.addAndGet(this, canonicalSize(message)); + noSpamLogger.warn("{} overloaded; dropping {} message (queue: {} local, {} endpoint, {} global)", + id(), + FBUtilities.prettyPrintMemory(canonicalSize(message)), + FBUtilities.prettyPrintMemory(pendingBytes()), + FBUtilities.prettyPrintMemory(reserveCapacityInBytes.endpoint.using()), + FBUtilities.prettyPrintMemory(reserveCapacityInBytes.global.using())); + callbacks.onOverloaded(message, template.to); + } + + /** + * Take any necessary cleanup action after a message has been selected to be discarded from the queue. + * + * Only to be invoked while holding OutboundMessageQueue.WithLock + */ + private boolean onExpired(Message message) + { + releaseCapacity(1, canonicalSize(message)); + expiredCount += 1; + expiredBytes += canonicalSize(message); + noSpamLogger.warn("{} dropping message of type {} whose timeout expired before reaching the network", id(), message.verb()); + callbacks.onExpired(message, template.to); + return true; + } + + /** + * Take any necessary cleanup action after a message has been selected to be discarded from the queue. 
+ * + * Only to be invoked by the delivery thread + */ + private void onFailedSerialize(Message message, int messagingVersion, int bytesWrittenToNetwork, Throwable t) + { + JVMStabilityInspector.inspectThrowable(t, false); + releaseCapacity(1, canonicalSize(message)); + errorCount += 1; + errorBytes += message.serializedSize(messagingVersion); + logger.warn("{} dropping message of type {} due to error", id(), message.verb(), t); + callbacks.onFailedSerialize(message, template.to, messagingVersion, bytesWrittenToNetwork, t); + } + + /** + * Take any necessary cleanup action after a message has been selected to be discarded from the queue on close. + * Note that this is only for messages that were queued prior to closing without graceful flush, OR + * for those that are unceremoniously dropped when we decide close has been trying to complete for too long. + */ + private void onClosed(Message message) + { + releaseCapacity(1, canonicalSize(message)); + callbacks.onDiscardOnClose(message, template.to); + } + + /** + * Delivery bundles the following: + * + * - the work that is necessary to actually deliver messages safely, and handle any exceptional states + * - the ability to schedule delivery for some time in the future + * - the ability to schedule some non-delivery work to happen some time in the future, that is guaranteed + * NOT to coincide with delivery for its duration, including any data that is being flushed (e.g. 
for closing channels) + * - this feature is *not* efficient, and should only be used for infrequent operations + */ + private abstract class Delivery extends AtomicInteger implements Runnable + { + final ExecutorService executor; + + // the AtomicInteger we extend always contains some combination of these bit flags, representing our current run state + + /** Not running, and will not be scheduled again until transitioned to a new state */ + private static final int STOPPED = 0; + /** Currently executing (may only be scheduled to execute, or may be about to terminate); + * will stop at end of this run, without rescheduling */ + private static final int EXECUTING = 1; + /** Another execution has been requested; a new execution will begin some time after this state is taken */ + private static final int EXECUTE_AGAIN = 2; + /** We are currently executing and will submit another execution before we terminate */ + private static final int EXECUTING_AGAIN = EXECUTING | EXECUTE_AGAIN; + /** Will begin a new execution some time after this state is taken, but only once some condition is met. + * This state will initially be taken in tandem with EXECUTING, but if delivery completes without clearing + * the state, the condition will be held on its own until {@link #executeAgain} is invoked */ + private static final int WAITING_TO_EXECUTE = 4; + + /** + * Force all task execution to stop, once any currently in progress work is completed + */ + private volatile boolean terminated; + + /** + * Is there asynchronous delivery work in progress. + * + * This temporarily prevents any {@link #stopAndRun} work from being performed. + * Once both inProgress and stopAndRun are set we perform no more delivery work until one is unset, + * to ensure we eventually run stopAndRun. + * + * This should be updated and read only on the Delivery thread. + */ + private boolean inProgress = false; + + /** + * Request a task's execution while there is no delivery work in progress. 
+ * + * This is to permit cleanly tearing down a connection without interrupting any messages that might be in flight. + * If stopAndRun is set, we should not enter doRun() until a corresponding setInProgress(false) occurs. + */ + final AtomicReference stopAndRun = new AtomicReference<>(); + + Delivery(ExecutorService executor) + { + this.executor = executor; + } + + /** + * Ensure that any messages or stopAndRun that were queued prior to this invocation will be seen by at least + * one future invocation of the delivery task, unless delivery has already been terminated. + */ + public void execute() + { + if (get() < EXECUTE_AGAIN && STOPPED == getAndUpdate(i -> i == STOPPED ? EXECUTING: i | EXECUTE_AGAIN)) + executor.execute(this); + } + + private boolean isExecuting(int state) + { + return 0 != (state & EXECUTING); + } + + /** + * This method is typically invoked after WAITING_TO_EXECUTE is set. + * + * However WAITING_TO_EXECUTE does not need to be set; all this method needs to ensure is that + * delivery unconditionally performs one new execution promptly. + */ + void executeAgain() + { + // if we are already executing, set EXECUTING_AGAIN and leave scheduling to the currently running one. + // otherwise, set ourselves unconditionally to EXECUTING and schedule ourselves immediately + if (!isExecuting(getAndUpdate(i -> !isExecuting(i) ? EXECUTING : EXECUTING_AGAIN))) + executor.execute(this); + } + + /** + * Invoke this when we cannot make further progress now, but we guarantee that we will execute later when we can. + * This simply communicates to {@link #run} that we should not schedule ourselves again, just unset the EXECUTING bit. + */ + void promiseToExecuteLater() + { + set(EXECUTING | WAITING_TO_EXECUTE); + } + + /** + * Called when exiting {@link #run} to schedule another run if necessary. + * + * If we are currently executing, we only reschedule if the present state is EXECUTING_AGAIN. 
+ * If this is the case, we clear the EXECUTE_AGAIN bit (setting ourselves to EXECUTING), and reschedule. + * Otherwise, we clear the EXECUTING bit and terminate, which will set us to either STOPPED or WAITING_TO_EXECUTE + * (or possibly WAITING_TO_EXECUTE | EXECUTE_AGAIN, which is logically the same as WAITING_TO_EXECUTE) + */ + private void maybeExecuteAgain() + { + if (EXECUTING_AGAIN == getAndUpdate(i -> i == EXECUTING_AGAIN ? EXECUTING : (i & ~EXECUTING))) + executor.execute(this); + } + + /** + * No more tasks or delivery will be executed, once any in progress complete. + */ + public void terminate() + { + terminated = true; + } + + /** + * Only to be invoked by the Delivery task. + * + * If true, indicates that we have begun asynchronous delivery work, so that + * we cannot safely stopAndRun until it completes. + * + * Once it completes, we ensure any stopAndRun task has a chance to execute + * by ensuring delivery is scheduled. + * + * If stopAndRun is also set, we should not enter doRun() until a corresponding + * setInProgress(false) occurs. + */ + void setInProgress(boolean inProgress) + { + boolean wasInProgress = this.inProgress; + this.inProgress = inProgress; + if (!inProgress && wasInProgress) + executeAgain(); + } + + /** + * Perform some delivery work. 
+ * + * Must never be invoked directly, only via {@link #execute()} + */ + public void run() + { + /* do/while handling setup for {@link #doRun()}, and repeat invocations thereof */ + while (true) + { + if (terminated) + return; + + if (null != stopAndRun.get()) + { + // if we have an external request to perform, attempt it - if no async delivery is in progress + + if (inProgress) + { + // if we are in progress, we cannot do anything; + // so, exit and rely on setInProgress(false) executing us + // (which must happen later, since it must happen on this thread) + promiseToExecuteLater(); + break; + } + + stopAndRun.getAndSet(null).run(); + } + + State state = OutboundConnection.this.state; + if (!state.isEstablished() || !state.established().isConnected()) + { + // if we have messages yet to deliver, or a task to run, we need to reconnect and try again + // we try to reconnect before running another stopAndRun so that we do not infinite loop in close + if (hasPending() || null != stopAndRun.get()) + { + promiseToExecuteLater(); + requestConnect().addListener(f -> executeAgain()); + } + break; + } + + if (!doRun(state.established())) + break; + } + + maybeExecuteAgain(); + } + + /** + * @return true if we should run again immediately; + * always false for eventLoop executor, as want to service other channels + */ + abstract boolean doRun(Established established); + + /** + * Schedule a task to run later on the delivery thread while delivery is not in progress, + * i.e. there are no bytes in flight to the network buffer. + * + * Does not guarantee to run promptly if there is no current connection to the remote host. + * May wait until a new connection is established, or a connection timeout elapses, before executing. + * + * Update the shared atomic property containing work we want to interrupt message processing to perform, + * the invoke schedule() to be certain it gets run. 
+ */ + void stopAndRun(Runnable run) + { + stopAndRun.accumulateAndGet(run, OutboundConnection::andThen); + execute(); + } + + /** + * Schedule a task to run on the eventLoop, guaranteeing that delivery will not occur while the task is performed. + */ + abstract void stopAndRunOnEventLoop(Runnable run); + + } + + /** + * Delivery that runs entirely on the eventLoop + * + * Since this has single threaded access to most of its environment, it can be simple and efficient, however + * it must also have bounded run time, and limit its resource consumption to ensure other channels serviced by the + * eventLoop can also make progress. + * + * This operates on modest buffers, no larger than the {@link OutboundConnections#LARGE_MESSAGE_THRESHOLD} and + * filling at most one at a time before writing (potentially asynchronously) to the socket. + * + * We track the number of bytes we have in flight, ensuring no more than a user-defined maximum at any one time. + */ + class EventLoopDelivery extends Delivery + { + private int flushingBytes; + private boolean isWritable = true; + + EventLoopDelivery() + { + super(eventLoop); + } + + /** + * {@link Delivery#doRun} + * + * Since we are on the eventLoop, in order to ensure other channels are serviced + * we never return true to request another run immediately. + * + * If there is more work to be done, we submit ourselves for execution once the eventLoop has time. 
+ */ + @SuppressWarnings("resource") + boolean doRun(Established established) + { + if (!isWritable) + return false; + + // pendingBytes is updated before queue.size() (which triggers notEmpty, and begins delivery), + // so it is safe to use it here to exit delivery + // this number is inaccurate for old versions, but we don't mind terribly - we'll send at least one message, + // and get round to it eventually (though we could add a fudge factor for some room for older versions) + int maxSendBytes = (int) min(pendingBytes() - flushingBytes, LARGE_MESSAGE_THRESHOLD); + if (maxSendBytes == 0) + return false; + + OutboundConnectionSettings settings = established.settings; + int messagingVersion = established.messagingVersion; + + FrameEncoder.Payload sending = null; + int canonicalSize = 0; // number of bytes we must use for our resource accounting + int sendingBytes = 0; + int sendingCount = 0; + try (OutboundMessageQueue.WithLock withLock = queue.lockOrCallback(approxTime.now(), this::execute)) + { + if (withLock == null) + return false; // we failed to acquire the queue lock, so return; we will be scheduled again when the lock is available + + sending = established.payloadAllocator.allocate(true, maxSendBytes); + DataOutputBufferFixed out = new DataOutputBufferFixed(sending.buffer); + + Message next; + while ( null != (next = withLock.peek()) ) + { + try + { + int messageSize = next.serializedSize(messagingVersion); + + // actual message size for this version is larger than permitted maximum + if (messageSize > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes()) + throw new Message.OversizedMessageException(messageSize); + + if (messageSize > sending.remaining()) + { + // if we don't have enough room to serialize the next message, we have either + // 1) run out of room after writing some messages successfully; this might mean that we are + // overflowing our highWaterMark, or that we have just filled our buffer + // 2) we have a message that is too large for 
this connection; this can happen if a message's + // size was calculated for the wrong messaging version when enqueued. + // In this case we want to write it anyway, so simply allocate a large enough buffer. + + if (sendingBytes > 0) + break; + + sending.release(); + sending = null; // set to null to prevent double-release if we fail to allocate our new buffer + sending = established.payloadAllocator.allocate(true, messageSize); + //noinspection IOResourceOpenedButNotSafelyClosed + out = new DataOutputBufferFixed(sending.buffer); + } + + Tracing.instance.traceOutgoingMessage(next, settings.connectTo); + Message.serializer.serialize(next, out, messagingVersion); + + if (sending.length() != sendingBytes + messageSize) + throw new InvalidSerializedSizeException(next.verb(), messageSize, sending.length() - sendingBytes); + + canonicalSize += canonicalSize(next); + sendingCount += 1; + sendingBytes += messageSize; + } + catch (Throwable t) + { + onFailedSerialize(next, messagingVersion, 0, t); + + assert sending != null; + // reset the buffer to ignore the message we failed to serialize + sending.trim(sendingBytes); + } + withLock.removeHead(next); + } + if (0 == sendingBytes) + return false; + + sending.finish(); + debug.onSendSmallFrame(sendingCount, sendingBytes); + ChannelFuture flushResult = AsyncChannelPromise.writeAndFlush(established.channel, sending); + sending = null; + + if (flushResult.isSuccess()) + { + sentCount += sendingCount; + sentBytes += sendingBytes; + debug.onSentSmallFrame(sendingCount, sendingBytes); + } + else + { + flushingBytes += canonicalSize; + setInProgress(true); + + boolean hasOverflowed = flushingBytes >= settings.flushHighWaterMark; + if (hasOverflowed) + { + isWritable = false; + promiseToExecuteLater(); + } + + int releaseBytesFinal = canonicalSize; + int sendingBytesFinal = sendingBytes; + int sendingCountFinal = sendingCount; + flushResult.addListener(future -> { + + releaseCapacity(sendingCountFinal, releaseBytesFinal); + 
flushingBytes -= releaseBytesFinal; + if (flushingBytes == 0) + setInProgress(false); + + if (!isWritable && flushingBytes <= settings.flushLowWaterMark) + { + isWritable = true; + executeAgain(); + } + + if (future.isSuccess()) + { + sentCount += sendingCountFinal; + sentBytes += sendingBytesFinal; + debug.onSentSmallFrame(sendingCountFinal, sendingBytesFinal); + } + else + { + errorCount += sendingCountFinal; + errorBytes += sendingBytesFinal; + invalidateChannel(established, future.cause()); + debug.onFailedSmallFrame(sendingCountFinal, sendingBytesFinal); + } + }); + canonicalSize = 0; + } + } + catch (Throwable t) + { + errorCount += sendingCount; + errorBytes += sendingBytes; + invalidateChannel(established, t); + } + finally + { + if (canonicalSize > 0) + releaseCapacity(sendingCount, canonicalSize); + + if (sending != null) + sending.release(); + + if (pendingBytes() > flushingBytes && isWritable) + execute(); + } + + return false; + } + + void stopAndRunOnEventLoop(Runnable run) + { + stopAndRun(run); + } + } + + /** + * Delivery that coordinates between the eventLoop and another (non-dedicated) thread + * + * This is to service messages that are too large to fully serialize on the eventLoop, as they could block + * prompt service of other requests. Since our serializers assume blocking IO, the easiest approach is to + * ensure a companion thread performs blocking IO that, under the hood, is serviced by async IO on the eventLoop. + * + * Most of the work here is handed off to {@link AsyncChannelOutputPlus}, with our main job being coordinating + * when and what we should run. + * + * To avoid allocating a huge number of threads across a cluster, we utilise the shared methods of {@link Delivery} + * to ensure that only one run() is actually scheduled to run at a time - this permits us to use any {@link ExecutorService} + * as a backing, with the number of threads defined only by the maximum concurrency needed to deliver all large messages. 
+ * We use a shared caching {@link java.util.concurrent.ThreadPoolExecutor}, and rename the Threads that service + * our connection on entry and exit. + */ + class LargeMessageDelivery extends Delivery + { + static final int DEFAULT_BUFFER_SIZE = 32 * 1024; + + LargeMessageDelivery(ExecutorService executor) + { + super(executor); + } + + /** + * A simple wrapper of {@link Delivery#run} to set the current Thread name for the duration of its execution. + */ + public void run() + { + String threadName, priorThreadName = null; + try + { + priorThreadName = Thread.currentThread().getName(); + threadName = "Messaging-OUT-" + template.from() + "->" + template.to + '-' + type; + Thread.currentThread().setName(threadName); + + super.run(); + } + finally + { + if (priorThreadName != null) + Thread.currentThread().setName(priorThreadName); + } + } + + @SuppressWarnings({ "resource", "RedundantSuppression" }) // make eclipse warnings go away + boolean doRun(Established established) + { + Message send = queue.tryPoll(approxTime.now(), this::execute); + if (send == null) + return false; + + AsyncMessageOutputPlus out = null; + try + { + int messageSize = send.serializedSize(established.messagingVersion); + out = new AsyncMessageOutputPlus(established.channel, DEFAULT_BUFFER_SIZE, messageSize, established.payloadAllocator); + // actual message size for this version is larger than permitted maximum + if (messageSize > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes()) + throw new Message.OversizedMessageException(messageSize); + + Tracing.instance.traceOutgoingMessage(send, established.settings.connectTo); + Message.serializer.serialize(send, out, established.messagingVersion); + + if (out.position() != messageSize) + throw new InvalidSerializedSizeException(send.verb(), messageSize, out.position()); + + out.close(); + sentCount += 1; + sentBytes += messageSize; + releaseCapacity(1, canonicalSize(send)); + return hasPending(); + } + catch (Throwable t) + { + boolean tryAgain 
= true; + + if (out != null) + { + out.discard(); + if (out.flushed() > 0 || + isCausedBy(t, cause -> isConnectionReset(cause) + || cause instanceof Errors.NativeIoException + || cause instanceof AsyncChannelOutputPlus.FlushException)) + { + // close the channel, and wait for eventLoop to execute + disconnectNow(established).awaitUninterruptibly(); + tryAgain = false; + try + { + // after closing, wait until we are signalled about the in flight writes; + // this ensures flushedToNetwork() is correct below + out.waitUntilFlushed(0, 0); + } + catch (Throwable ignore) + { + // irrelevant + } + } + } + + onFailedSerialize(send, established.messagingVersion, out == null ? 0 : (int) out.flushedToNetwork(), t); + return tryAgain; + } + } + + void stopAndRunOnEventLoop(Runnable run) + { + stopAndRun(() -> { + try + { + runOnEventLoop(run).await(); + } + catch (InterruptedException e) + { + throw new RuntimeException(e); + } + }); + } + } + + /* + * Size used for capacity enforcement purposes. Using current messaging version no matter what the peer's version is. + */ + private int canonicalSize(Message message) + { + return message.serializedSize(current_version); + } + + private void invalidateChannel(Established established, Throwable cause) + { + JVMStabilityInspector.inspectThrowable(cause, false); + + if (state != established) + return; // do nothing; channel already invalidated + + if (isCausedByConnectionReset(cause)) + logger.info("{} channel closed by provider", id(), cause); + else + logger.error("{} channel in potentially inconsistent state after error; closing", id(), cause); + + disconnectNow(established); + } + + /** + * Attempt to open a new channel to the remote endpoint. + * + * Most of the actual work is performed by OutboundConnectionInitiator, this method just manages + * our book keeping on either success or failure. 
+ * + * This method is only to be invoked by the eventLoop, and the inner class' methods should only be evaluated by the eventtLoop + */ + Future initiate() + { + class Initiate + { + /** + * If we fail to connect, we want to try and connect again before any messages timeout. + * However, we update this each time to ensure we do not retry unreasonably often, and settle on a periodicity + * that might lead to timeouts in some aggressive systems. + */ + long retryRateMillis = DatabaseDescriptor.getMinRpcTimeout(MILLISECONDS) / 2; + + // our connection settings, possibly updated on retry + int messagingVersion = template.endpointToVersion().get(template.to); + OutboundConnectionSettings settings; + + /** + * If we failed for any reason, try again + */ + void onFailure(Throwable cause) + { + if (cause instanceof ConnectException) + noSpamLogger.info("{} failed to connect", id(), cause); + else + noSpamLogger.error("{} failed to connect", id(), cause); + + JVMStabilityInspector.inspectThrowable(cause, false); + + if (hasPending()) + { + Promise> result = new AsyncPromise<>(eventLoop); + state = new Connecting(state.disconnected(), result, eventLoop.schedule(() -> attempt(result), max(100, retryRateMillis), MILLISECONDS)); + retryRateMillis = min(1000, retryRateMillis * 2); + } + else + { + // this Initiate will be discarded + state = Disconnected.dormant(state.disconnected().maintenance); + } + } + + void onCompletedHandshake(Result result) + { + switch (result.outcome) + { + case SUCCESS: + // it is expected that close, if successful, has already cancelled us; so we do not need to worry about leaking connections + assert !state.isClosed(); + + MessagingSuccess success = result.success(); + debug.onConnect(success.messagingVersion, settings); + state.disconnected().maintenance.cancel(false); + + FrameEncoder.PayloadAllocator payloadAllocator = success.allocator; + Channel channel = success.channel; + Established established = new Established(messagingVersion, channel, 
payloadAllocator, settings); + state = established; + channel.pipeline().addLast("handleExceptionalStates", new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + disconnectNow(established); + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + try + { + invalidateChannel(established, cause); + } + catch (Throwable t) + { + logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + } + } + }); + ++successfulConnections; + + logger.info("{} successfully connected, version = {}, framing = {}, encryption = {}", + id(true), + success.messagingVersion, + settings.framing, + encryptionLogStatement(settings.encryption)); + break; + + case RETRY: + if (logger.isTraceEnabled()) + logger.trace("{} incorrect legacy peer version predicted; reconnecting", id()); + + // the messaging version we connected with was incorrect; try again with the one supplied by the remote host + messagingVersion = result.retry().withMessagingVersion; + settings.endpointToVersion.set(settings.to, messagingVersion); + + initiate(); + break; + + case INCOMPATIBLE: + // we cannot communicate with this peer given its messaging version; mark this as any other failure, and continue trying + Throwable t = new IOException(String.format("Incompatible peer: %s, messaging version: %s", + settings.to, result.incompatible().maxMessagingVersion)); + t.fillInStackTrace(); + onFailure(t); + break; + + default: + throw new AssertionError(); + } + } + + /** + * Initiate all the actions required to establish a working, valid connection. This includes + * opening the socket, negotiating the internode messaging handshake, and setting up the working + * Netty {@link Channel}. However, this method will not block for all those actions: it will only + * kick off the connection attempt, setting the @{link #connecting} future to track its completion. 
+ * + * Note: this should only be invoked on the event loop. + */ + private void attempt(Promise> result) + { + ++connectionAttempts; + + settings = template; + if (messagingVersion > settings.acceptVersions.max) + messagingVersion = settings.acceptVersions.max; + + // ensure we connect to the correct SSL port + settings = settings.withLegacyPortIfNecessary(messagingVersion); + + initiateMessaging(eventLoop, type, settings, messagingVersion, result) + .addListener(future -> { + if (future.isCancelled()) + return; + if (future.isSuccess()) //noinspection unchecked + onCompletedHandshake((Result) future.getNow()); + else + onFailure(future.cause()); + }); + } + + Future> initiate() + { + Promise> result = new AsyncPromise<>(eventLoop); + state = new Connecting(state.disconnected(), result); + attempt(result); + return result; + } + } + + return new Initiate().initiate(); + } + + /** + * Returns a future that completes when we are _maybe_ reconnected. + * + * The connection attempt is guaranteed to have completed (successfully or not) by the time any listeners are invoked, + * so if a reconnection attempt is needed, it is already scheduled. 
+ */ + private Future requestConnect() + { + // we may race with updates to this variable, but this is fine, since we only guarantee that we see a value + // that did at some point represent an active connection attempt - if it is stale, it will have been completed + // and the caller can retry (or utilise the successfully established connection) + { + State state = this.state; + if (state.isConnecting()) + return state.connecting().attempt; + } + + Promise promise = AsyncPromise.uncancellable(eventLoop); + runOnEventLoop(() -> { + if (isClosed()) // never going to connect + { + promise.tryFailure(new ClosedChannelException()); + } + else if (state.isEstablished() && state.established().isConnected()) // already connected + { + promise.trySuccess(null); + } + else + { + if (state.isEstablished()) + setDisconnected(); + + if (!state.isConnecting()) + { + assert eventLoop.inEventLoop(); + assert !isConnected(); + initiate().addListener(new PromiseNotifier<>(promise)); + } + else + { + state.connecting().attempt.addListener(new PromiseNotifier<>(promise)); + } + } + }); + return promise; + } + + /** + * Change the IP address on which we connect to the peer. We will attempt to connect to the new address if there + * was a previous connection, and new incoming messages as well as existing {@link #queue} messages will be sent there. + * Any outstanding messages in the existing channel will still be sent to the previous address (we won't/can't move them from + * one channel to another). + * + * Returns null if the connection is closed. 
+ */ + Future reconnectWith(OutboundConnectionSettings reconnectWith) + { + OutboundConnectionSettings newTemplate = reconnectWith.withDefaults(ConnectionCategory.MESSAGING); + if (newTemplate.socketFactory != template.socketFactory) throw new IllegalArgumentException(); + if (newTemplate.callbacks != template.callbacks) throw new IllegalArgumentException(); + if (!Objects.equals(newTemplate.applicationSendQueueCapacityInBytes, template.applicationSendQueueCapacityInBytes)) throw new IllegalArgumentException(); + if (!Objects.equals(newTemplate.applicationSendQueueReserveEndpointCapacityInBytes, template.applicationSendQueueReserveEndpointCapacityInBytes)) throw new IllegalArgumentException(); + if (newTemplate.applicationSendQueueReserveGlobalCapacityInBytes != template.applicationSendQueueReserveGlobalCapacityInBytes) throw new IllegalArgumentException(); + + logger.info("{} updating connection settings", id()); + + Promise done = AsyncPromise.uncancellable(eventLoop); + delivery.stopAndRunOnEventLoop(() -> { + template = newTemplate; + // delivery will immediately continue after this, triggering a reconnect if necessary; + // this might mean a slight delay for large message delivery, as the connect will be scheduled + // asynchronously, so we must wait for a second turn on the eventLoop + if (state.isEstablished()) + { + disconnectNow(state.established()); + } + else if (state.isConnecting()) + { + // cancel any in-flight connection attempt and restart with new template + state.connecting().cancel(); + initiate(); + } + done.setSuccess(null); + }); + return done; + } + + /** + * Close any currently open connection, forcing a reconnect if there are messages outstanding + * (or leaving it closed for now otherwise) + */ + public boolean interrupt() + { + State state = this.state; + if (!state.isEstablished()) + return false; + + disconnectGracefully(state.established()); + return true; + } + + /** + * Schedule a safe close of the provided channel, if it has not 
already been closed. + * + * This means ensuring that delivery has stopped so that we do not corrupt or interrupt any + * in progress transmissions. + * + * The actual closing of the channel is performed asynchronously, to simplify our internal state management + * and promptly get the connection going again; the close is considered to have succeeded as soon as we + * have set our internal state. + */ + private void disconnectGracefully(Established closeIfIs) + { + // delivery will immediately continue after this, triggering a reconnect if necessary; + // this might mean a slight delay for large message delivery, as the connect will be scheduled + // asynchronously, so we must wait for a second turn on the eventLoop + delivery.stopAndRunOnEventLoop(() -> disconnectNow(closeIfIs)); + } + + /** + * The channel is already known to be invalid, so there's no point waiting for a clean break in delivery. + * + * Delivery will be executed again as soon as we have logically closed the channel; we do not wait + * for the channel to actually be closed. + * + * The Future returned _does_ wait for the channel to be completely closed, so that callers can wait to be sure + * all writes have been completed either successfully or not. + */ + private Future disconnectNow(Established closeIfIs) + { + return runOnEventLoop(() -> { + if (state == closeIfIs) + { + // no need to wait until the channel is closed to set ourselves as disconnected (and potentially open a new channel) + setDisconnected(); + if (hasPending()) + delivery.execute(); + closeIfIs.channel.close() + .addListener(future -> { + if (!future.isSuccess()) + logger.info("Problem closing channel {}", closeIfIs, future.cause()); + }); + } + }); + } + + /** + * Schedules regular cleaning of the connection's state while it is disconnected from its remote endpoint. 
+ * + * To be run only by the eventLoop or in the constructor + */ + private void setDisconnected() + { + assert state == null || state.isEstablished(); + state = Disconnected.dormant(eventLoop.scheduleAtFixedRate(queue::maybePruneExpired, 100L, 100L, TimeUnit.MILLISECONDS)); + } + + /** + * Schedule this connection to be permanently closed; only one close may be scheduled, + * any future scheduled closes are referred to the original triggering one (which may have a different schedule) + */ + Future scheduleClose(long time, TimeUnit unit, boolean flushQueue) + { + Promise scheduledClose = AsyncPromise.uncancellable(eventLoop); + if (!scheduledCloseUpdater.compareAndSet(this, null, scheduledClose)) + return this.scheduledClose; + + eventLoop.schedule(() -> close(flushQueue).addListener(new PromiseNotifier<>(scheduledClose)), time, unit); + return scheduledClose; + } + + /** + * Permanently close this connection. + * + * Immediately prevent any new messages from being enqueued - these will throw ClosedChannelException. + * The close itself happens asynchronously on the eventLoop, so a Future is returned to help callers + * wait for its completion. + * + * The flushQueue parameter indicates if any outstanding messages should be delivered before closing the connection. + * + * - If false, any already flushed or in-progress messages are completed, and the remaining messages are cleared + * before the connection is promptly torn down. + * + * - If true, we attempt delivery of all queued messages. If necessary, we will continue to open new connections + * to the remote host until they have been delivered. Only if we continue to fail to open a connection for + * an extended period of time will we drop any outstanding messages and close the connection. 
+ */ + public Future close(boolean flushQueue) + { + // ensure only one close attempt can be in flight + Promise closing = AsyncPromise.uncancellable(eventLoop); + if (!closingUpdater.compareAndSet(this, null, closing)) + return this.closing; + + /* + * Now define a cleanup closure, that will be deferred until it is safe to do so. + * Once run it: + * - immediately _logically_ closes the channel by updating this object's fields, but defers actually closing + * - cancels any in-flight connection attempts + * - cancels any maintenance work that might be scheduled + * - clears any waiting messages on the queue + * - terminates the delivery thread + * - finally, schedules any open channel's closure, and propagates its completion to the close promise + */ + Runnable eventLoopCleanup = () -> { + Runnable onceNotConnecting = () -> { + // start by setting ourselves to definitionally closed + State state = this.state; + this.state = State.CLOSED; + + try + { + // note that we never clear the queue, to ensure that an enqueue has the opportunity to remove itself + // if it raced with close, to potentially requeue the message on a replacement connection + + // we terminate delivery here, to ensure that any listener to {@link connecting} do not schedule more work + delivery.terminate(); + + // stop periodic cleanup + if (state.isDisconnected()) + { + state.disconnected().maintenance.cancel(true); + closing.setSuccess(null); + } + else + { + assert state.isEstablished(); + state.established().channel.close() + .addListener(new PromiseNotifier<>(closing)); + } + } + catch (Throwable t) + { + // in case of unexpected exception, signal completion and try to close the channel + closing.trySuccess(null); + try + { + if (state.isEstablished()) + state.established().channel.close(); + } + catch (Throwable t2) + { + t.addSuppressed(t2); + logger.error("Failed to close connection cleanly:", t); + } + throw t; + } + }; + + if (state.isConnecting()) + { + // stop any in-flight connection 
attempts; these should be running on the eventLoop, so we should + // be able to cleanly cancel them, but executing on a listener guarantees correct semantics either way + Connecting connecting = state.connecting(); + connecting.cancel(); + connecting.attempt.addListener(future -> onceNotConnecting.run()); + } + else + { + onceNotConnecting.run(); + } + }; + + /* + * If we want to shutdown gracefully, flushing any outstanding messages, we have to do it very carefully. + * Things to note: + * + * - It is possible flushing messages will require establishing a new connection + * (However, if a connection cannot be established, we do not want to keep trying) + * - We have to negotiate with a separate thread, so we must be sure it is not in-progress before we stop (like channel close) + * - Cleanup must still happen on the eventLoop + * + * To achieve all of this, we schedule a recurring operation on the delivery thread, executing while delivery + * is between messages, that checks if the queue is empty; if it is, it schedules cleanup on the eventLoop. 
+ */ + + Runnable clearQueue = () -> + { + CountDownLatch done = new CountDownLatch(1); + queue.runEventually(withLock -> { + withLock.consume(this::onClosed); + done.countDown(); + }); + //noinspection UnstableApiUsage + Uninterruptibles.awaitUninterruptibly(done); + }; + + if (flushQueue) + { + // just keep scheduling on delivery executor a check to see if we're done; there should always be one + // delivery attempt between each invocation, unless there is a wider problem with delivery scheduling + class FinishDelivery implements Runnable + { + public void run() + { + if (!hasPending()) + delivery.stopAndRunOnEventLoop(eventLoopCleanup); + else + delivery.stopAndRun(() -> { + if (state.isConnecting() && state.connecting().isFailingToConnect) + clearQueue.run(); + run(); + }); + } + } + + delivery.stopAndRun(new FinishDelivery()); + } + else + { + delivery.stopAndRunOnEventLoop(() -> { + clearQueue.run(); + eventLoopCleanup.run(); + }); + } + + return closing; + } + + /** + * Run the task immediately if we are the eventLoop, otherwise queue it for execution on the eventLoop. 
+ */ + private Future runOnEventLoop(Runnable runnable) + { + if (!eventLoop.inEventLoop()) + return eventLoop.submit(runnable); + + runnable.run(); + return new SucceededFuture<>(eventLoop, null); + } + + public boolean isConnected() + { + State state = this.state; + return state.isEstablished() && state.established().isConnected(); + } + + boolean isClosing() + { + return closing != null; + } + + boolean isClosed() + { + return state.isClosed(); + } + + private String id(boolean includeReal) + { + State state = this.state; + if (!includeReal || !state.isEstablished()) + return id(); + Established established = state.established(); + Channel channel = established.channel; + OutboundConnectionSettings settings = established.settings; + return SocketFactory.channelId(settings.from, (InetSocketAddress) channel.remoteAddress(), + settings.to, (InetSocketAddress) channel.localAddress(), + type, channel.id().asShortText()); + } + + private String id() + { + State state = this.state; + Channel channel = null; + OutboundConnectionSettings settings = template; + if (state.isEstablished()) + { + channel = state.established().channel; + settings = state.established().settings; + } + String channelId = channel != null ? 
channel.id().asShortText() : "[no-channel]"; + return SocketFactory.channelId(settings.from(), settings.to, type, channelId); + } + + @Override + public String toString() + { + return id(); + } + + public boolean hasPending() + { + return 0 != pendingCountAndBytes; + } + + public int pendingCount() + { + return pendingCount(pendingCountAndBytes); + } + + public long pendingBytes() + { + return pendingBytes(pendingCountAndBytes); + } + + public long sentCount() + { + // not volatile, but shouldn't matter + return sentCount; + } + + public long sentBytes() + { + // not volatile, but shouldn't matter + return sentBytes; + } + + public long submittedCount() + { + // not volatile, but shouldn't matter + return submittedCount; + } + + public long dropped() + { + return overloadedCount + expiredCount; + } + + public long overloadedBytes() + { + return overloadedBytes; + } + + public long overloadedCount() + { + return overloadedCount; + } + + public long expiredCount() + { + return expiredCount; + } + + public long expiredBytes() + { + return expiredBytes; + } + + public long errorCount() + { + return errorCount; + } + + public long errorBytes() + { + return errorBytes; + } + + public long successfulConnections() + { + return successfulConnections; + } + + public long connectionAttempts() + { + return connectionAttempts; + } + + private static Runnable andThen(Runnable a, Runnable b) + { + if (a == null || b == null) + return a == null ? b : a; + return () -> { a.run(); b.run(); }; + } + + @VisibleForTesting + public ConnectionType type() + { + return type; + } + + @VisibleForTesting + OutboundConnectionSettings settings() + { + State state = this.state; + return state.isEstablished() ? state.established().settings : template; + } + + @VisibleForTesting + int messagingVersion() + { + State state = this.state; + return state.isEstablished() ? 
state.established().messagingVersion + : template.endpointToVersion().get(template.to); + } + + @VisibleForTesting + void unsafeRunOnDelivery(Runnable run) + { + delivery.stopAndRun(run); + } + + @VisibleForTesting + Channel unsafeGetChannel() + { + State state = this.state; + return state.isEstablished() ? state.established().channel : null; + } + + @VisibleForTesting + boolean unsafeAcquireCapacity(long amount) + { + return SUCCESS == acquireCapacity(amount); + } + + @VisibleForTesting + boolean unsafeAcquireCapacity(long count, long amount) + { + return SUCCESS == acquireCapacity(count, amount); + } + + @VisibleForTesting + void unsafeReleaseCapacity(long amount) + { + releaseCapacity(1, amount); + } + + Limit unsafeGetEndpointReserveLimits() + { + return reserveCapacityInBytes.endpoint; + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java new file mode 100644 index 000000000000..fdfb2dfa74e1 --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; + +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.FailedFuture; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.ScheduledFuture; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.HandshakeProtocol.Initiate; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result.MessagingSuccess; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result.StreamingSuccess; +import org.apache.cassandra.security.SSLFactory; +import org.apache.cassandra.utils.JVMStabilityInspector; +import org.apache.cassandra.utils.memory.BufferPool; + +import static java.util.concurrent.TimeUnit.*; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.HandshakeProtocol.*; +import static org.apache.cassandra.net.ConnectionType.STREAMING; +import static org.apache.cassandra.net.OutboundConnectionInitiator.Result.incompatible; +import static 
org.apache.cassandra.net.OutboundConnectionInitiator.Result.messagingSuccess; +import static org.apache.cassandra.net.OutboundConnectionInitiator.Result.retry; +import static org.apache.cassandra.net.OutboundConnectionInitiator.Result.streamingSuccess; +import static org.apache.cassandra.net.SocketFactory.*; + +/** + * A {@link ChannelHandler} to execute the send-side of the internode handshake protocol. + * As soon as the handler is added to the channel via {@link ChannelInboundHandler#channelActive(ChannelHandlerContext)} + * (which is only invoked if the underlying TCP connection was properly established), the {@link Initiate} + * handshake is sent. See {@link HandshakeProtocol} for full details. + *

+ * Upon completion of the handshake (on success or fail), the {@link #resultPromise} is completed. + * See {@link Result} for details about the different result states. + *

+ * This class extends {@link ByteToMessageDecoder}, which is a {@link ChannelInboundHandler}, because this handler + * waits for the peer's handshake response (the {@link Accept} of the internode messaging handshake protocol). + */ +public class OutboundConnectionInitiator +{ + private static final Logger logger = LoggerFactory.getLogger(OutboundConnectionInitiator.class); + + private final ConnectionType type; + private final OutboundConnectionSettings settings; + private final int requestMessagingVersion; // for pre40 nodes + private final Promise> resultPromise; + + private OutboundConnectionInitiator(ConnectionType type, OutboundConnectionSettings settings, + int requestMessagingVersion, Promise> resultPromise) + { + this.type = type; + this.requestMessagingVersion = requestMessagingVersion; + this.settings = settings; + this.resultPromise = resultPromise; + } + + /** + * Initiate a connection with the requested messaging version. + * if the other node supports a newer version, or doesn't support this version, we will fail to connect + * and try again with the version they reported + * + * The returned {@code Future} is guaranteed to be completed on the supplied eventLoop. + */ + public static Future> initiateStreaming(EventLoop eventLoop, OutboundConnectionSettings settings, int requestMessagingVersion) + { + return new OutboundConnectionInitiator(STREAMING, settings, requestMessagingVersion, new AsyncPromise<>(eventLoop)) + .initiate(eventLoop); + } + + /** + * Initiate a connection with the requested messaging version. + * if the other node supports a newer version, or doesn't support this version, we will fail to connect + * and try again with the version they reported + * + * The returned {@code Future} is guaranteed to be completed on the supplied eventLoop. 
+ */ + static Future> initiateMessaging(EventLoop eventLoop, ConnectionType type, OutboundConnectionSettings settings, int requestMessagingVersion, Promise> result) + { + return new OutboundConnectionInitiator<>(type, settings, requestMessagingVersion, result) + .initiate(eventLoop); + } + + private Future> initiate(EventLoop eventLoop) + { + if (logger.isTraceEnabled()) + logger.trace("creating outbound bootstrap to {}, requestVersion: {}", settings, requestMessagingVersion); + + if (!settings.authenticate()) + { + // interrupt other connections, so they must attempt to re-authenticate + MessagingService.instance().interruptOutbound(settings.to); + return new FailedFuture<>(eventLoop, new IOException("authentication failed to " + settings.to)); + } + + // this is a bit ugly, but is the easiest way to ensure that if we timeout we can propagate a suitable error message + // and still guarantee that, if on timing out we raced with success, the successfully created channel is handled + AtomicBoolean timedout = new AtomicBoolean(); + Future bootstrap = createBootstrap(eventLoop) + .connect() + .addListener(future -> { + eventLoop.execute(() -> { + if (!future.isSuccess()) + { + if (future.isCancelled() && !timedout.get()) + resultPromise.cancel(true); + else if (future.isCancelled()) + resultPromise.tryFailure(new IOException("Timeout handshaking with " + settings.connectTo)); + else + resultPromise.tryFailure(future.cause()); + } + }); + }); + + ScheduledFuture timeout = eventLoop.schedule(() -> { + timedout.set(true); + bootstrap.cancel(false); + }, TIMEOUT_MILLIS, MILLISECONDS); + bootstrap.addListener(future -> timeout.cancel(true)); + + // Note that the bootstrap future's listeners may be invoked outside of the eventLoop, + // as Epoll failures on connection and disconnect may be run on the GlobalEventExecutor + // Since this FutureResult's listeners are all given to our resultPromise, they are guaranteed to be invoked by the eventLoop. 
+ return new FutureResult<>(resultPromise, bootstrap); + } + + /** + * Create the {@link Bootstrap} for connecting to a remote peer. This method does not attempt to connect to the peer, + * and thus does not block. + */ + private Bootstrap createBootstrap(EventLoop eventLoop) + { + Bootstrap bootstrap = settings.socketFactory + .newClientBootstrap(eventLoop, settings.tcpUserTimeoutInMS) + .option(ChannelOption.ALLOCATOR, GlobalBufferPoolAllocator.instance) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.tcpConnectTimeoutInMS) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.SO_REUSEADDR, true) + .option(ChannelOption.TCP_NODELAY, settings.tcpNoDelay) + .option(ChannelOption.MESSAGE_SIZE_ESTIMATOR, NoSizeEstimator.instance) + .handler(new Initializer()); + + if (settings.socketSendBufferSizeInBytes > 0) + bootstrap.option(ChannelOption.SO_SNDBUF, settings.socketSendBufferSizeInBytes); + + InetAddressAndPort remoteAddress = settings.connectTo; + bootstrap.remoteAddress(new InetSocketAddress(remoteAddress.address, remoteAddress.port)); + return bootstrap; + } + + private class Initializer extends ChannelInitializer + { + public void initChannel(SocketChannel channel) throws Exception + { + ChannelPipeline pipeline = channel.pipeline(); + + // order of handlers: ssl -> logger -> handshakeHandler + if (settings.withEncryption()) + { + // check if we should actually encrypt this connection + SslContext sslContext = SSLFactory.getOrCreateSslContext(settings.encryption, true, SSLFactory.SocketType.CLIENT); + // for some reason channel.remoteAddress() will return null + InetAddressAndPort address = settings.to; + InetSocketAddress peer = settings.encryption.require_endpoint_verification ? 
new InetSocketAddress(address.address, address.port) : null; + SslHandler sslHandler = newSslHandler(channel, sslContext, peer); + logger.trace("creating outbound netty SslContext: context={}, engine={}", sslContext.getClass().getName(), sslHandler.engine().getClass().getName()); + pipeline.addFirst("ssl", sslHandler); + } + + if (WIRETRACE) + pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO)); + + pipeline.addLast("handshake", new Handler()); + } + + } + + private class Handler extends ByteToMessageDecoder + { + /** + * {@inheritDoc} + * + * Invoked when the channel is made active, and sends out the {@link Initiate}. + * In the case of streaming, we do not require a full bi-directional handshake; the initial message, + * containing the streaming protocol version, is all that is required. + */ + @Override + public void channelActive(final ChannelHandlerContext ctx) + { + Initiate msg = new Initiate(requestMessagingVersion, settings.acceptVersions, type, settings.framing, settings.from); + logger.trace("starting handshake with peer {}, msg = {}", settings.connectTo, msg); + AsyncChannelPromise.writeAndFlush(ctx, msg.encode(), + future -> { if (!future.isSuccess()) exceptionCaught(ctx, future.cause()); }); + + if (type.isStreaming() && requestMessagingVersion < VERSION_40) + ctx.pipeline().remove(this); + + ctx.fireChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception + { + super.channelInactive(ctx); + resultPromise.tryFailure(new ClosedChannelException()); + } + + /** + * {@inheritDoc} + * + * Invoked when we get the response back from the peer, which should contain the second message of the internode messaging handshake. + *

+ * If the peer's protocol version does not equal what we were expecting, immediately close the channel (and socket); + * do *not* send out the third message of the internode messaging handshake. + * We will reconnect on the appropriate protocol version. + */ + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) + { + try + { + Accept msg = Accept.maybeDecode(in, requestMessagingVersion); + if (msg == null) + return; + + int useMessagingVersion = msg.useMessagingVersion; + int peerMessagingVersion = msg.maxMessagingVersion; + logger.trace("received second handshake message from peer {}, msg = {}", settings.connectTo, msg); + + FrameEncoder frameEncoder = null; + Result result; + if (useMessagingVersion > 0) + { + if (useMessagingVersion < settings.acceptVersions.min || useMessagingVersion > settings.acceptVersions.max) + { + result = incompatible(useMessagingVersion, peerMessagingVersion); + } + else + { + // This is a bit ugly + if (type.isMessaging()) + { + switch (settings.framing) + { + case LZ4: + frameEncoder = FrameEncoderLZ4.fastInstance; + break; + case CRC: + frameEncoder = FrameEncoderCrc.instance; + break; + case UNPROTECTED: + frameEncoder = FrameEncoderUnprotected.instance; + break; + } + + result = (Result) messagingSuccess(ctx.channel(), useMessagingVersion, frameEncoder.allocator()); + } + else + { + result = (Result) streamingSuccess(ctx.channel(), useMessagingVersion); + } + } + } + else + { + assert type.isMessaging(); + + // pre40 handshake responses only (can be a post40 node) + if (peerMessagingVersion == requestMessagingVersion + || peerMessagingVersion > settings.acceptVersions.max) // this clause is for impersonating 3.0 node in testing only + { + switch (settings.framing) + { + case CRC: + case UNPROTECTED: + frameEncoder = FrameEncoderLegacy.instance; + break; + case LZ4: + frameEncoder = FrameEncoderLegacyLZ4.instance; + break; + } + + result = (Result) messagingSuccess(ctx.channel(), 
requestMessagingVersion, frameEncoder.allocator()); + } + else if (peerMessagingVersion < settings.acceptVersions.min) + result = incompatible(-1, peerMessagingVersion); + else + result = retry(peerMessagingVersion); + + if (result.isSuccess()) + { + ConfirmOutboundPre40 message = new ConfirmOutboundPre40(settings.acceptVersions.max, settings.from); + AsyncChannelPromise.writeAndFlush(ctx, message.encode()); + } + } + + ChannelPipeline pipeline = ctx.pipeline(); + if (result.isSuccess()) + { + BufferPool.setRecycleWhenFreeForCurrentThread(false); + if (type.isMessaging()) + { + assert frameEncoder != null; + pipeline.addLast("frameEncoder", frameEncoder); + } + pipeline.remove(this); + } + else + { + pipeline.close(); + } + + if (!resultPromise.trySuccess(result) && result.isSuccess()) + result.success().channel.close(); + } + catch (Throwable t) + { + exceptionCaught(ctx, t); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + try + { + JVMStabilityInspector.inspectThrowable(cause, false); + resultPromise.tryFailure(cause); + if (isCausedByConnectionReset(cause)) + logger.info("Failed to connect to peer {}", settings.to, cause); + else + logger.error("Failed to handshake with peer {}", settings.to, cause); + ctx.close(); + } + catch (Throwable t) + { + logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + } + } + } + + /** + * The result of the handshake. Handshake has 3 possible outcomes: + * 1) it can be successful, in which case the channel and version to used is returned in this result. + * 2) we may decide to disconnect to reconnect with another protocol version (namely, the version is passed in this result). + * 3) we can have a negotiation failure for an unknown reason. 
(#sadtrombone) + */ + public static class Result + { + /** + * Describes the result of receiving the response back from the peer (Message 2 of the handshake) + * and implies an action that should be taken. + */ + enum Outcome + { + SUCCESS, RETRY, INCOMPATIBLE + } + + public static class Success extends Result + { + public final Channel channel; + public final int messagingVersion; + Success(Channel channel, int messagingVersion) + { + super(Outcome.SUCCESS); + this.channel = channel; + this.messagingVersion = messagingVersion; + } + } + + public static class StreamingSuccess extends Success + { + StreamingSuccess(Channel channel, int messagingVersion) + { + super(channel, messagingVersion); + } + } + + public static class MessagingSuccess extends Success + { + public final FrameEncoder.PayloadAllocator allocator; + MessagingSuccess(Channel channel, int messagingVersion, FrameEncoder.PayloadAllocator allocator) + { + super(channel, messagingVersion); + this.allocator = allocator; + } + } + + static class Retry extends Result + { + final int withMessagingVersion; + Retry(int withMessagingVersion) + { + super(Outcome.RETRY); + this.withMessagingVersion = withMessagingVersion; + } + } + + static class Incompatible extends Result + { + final int closestSupportedVersion; + final int maxMessagingVersion; + Incompatible(int closestSupportedVersion, int maxMessagingVersion) + { + super(Outcome.INCOMPATIBLE); + this.closestSupportedVersion = closestSupportedVersion; + this.maxMessagingVersion = maxMessagingVersion; + } + } + + final Outcome outcome; + + private Result(Outcome outcome) + { + this.outcome = outcome; + } + + boolean isSuccess() { return outcome == Outcome.SUCCESS; } + public SuccessType success() { return (SuccessType) this; } + static MessagingSuccess messagingSuccess(Channel channel, int messagingVersion, FrameEncoder.PayloadAllocator allocator) { return new MessagingSuccess(channel, messagingVersion, allocator); } + static StreamingSuccess 
streamingSuccess(Channel channel, int messagingVersion) { return new StreamingSuccess(channel, messagingVersion); } + + public Retry retry() { return (Retry) this; } + static Result retry(int withMessagingVersion) { return new Retry<>(withMessagingVersion); } + + public Incompatible incompatible() { return (Incompatible) this; } + static Result incompatible(int closestSupportedVersion, int maxMessagingVersion) { return new Incompatible(closestSupportedVersion, maxMessagingVersion); } + } + +} diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionSettings.java b/src/java/org/apache/cassandra/net/OutboundConnectionSettings.java new file mode 100644 index 000000000000..c78df61326c3 --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundConnectionSettings.java @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import io.netty.channel.WriteBufferWaterMark; +import org.apache.cassandra.auth.IInternodeAuthenticator; +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; +import org.apache.cassandra.db.SystemKeyspace; +import org.apache.cassandra.locator.IEndpointSnitch; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.utils.FBUtilities; + +import static org.apache.cassandra.config.DatabaseDescriptor.getEndpointSnitch; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.MessagingService.instance; +import static org.apache.cassandra.net.SocketFactory.encryptionLogStatement; +import static org.apache.cassandra.utils.FBUtilities.getBroadcastAddressAndPort; + +/** + * A collection of settings to be passed around for outbound connections. + */ +@SuppressWarnings({ "WeakerAccess", "unused" }) +public class OutboundConnectionSettings +{ + private static final String INTRADC_TCP_NODELAY_PROPERTY = Config.PROPERTY_PREFIX + "otc_intradc_tcp_nodelay"; + /** + * Enabled/disable TCP_NODELAY for intradc connections. Defaults to enabled. 
+ */ + private static final boolean INTRADC_TCP_NODELAY = Boolean.parseBoolean(System.getProperty(INTRADC_TCP_NODELAY_PROPERTY, "true")); + + public enum Framing + { + // for < VERSION_40, implies no framing + // for >= VERSION_40, uses simple unprotected frames with header crc but no payload protection + UNPROTECTED(0), + // for < VERSION_40, uses the jpountz framing format + // for >= VERSION_40, uses our framing format with header crc24 + LZ4(1), + // for < VERSION_40, implies UNPROTECTED + // for >= VERSION_40, uses simple frames with separate header and payload crc + CRC(2); + + public static Framing forId(int id) + { + switch (id) + { + case 0: return UNPROTECTED; + case 1: return LZ4; + case 2: return CRC; + } + throw new IllegalStateException(); + } + + final int id; + Framing(int id) + { + this.id = id; + } + } + + public final IInternodeAuthenticator authenticator; + public final InetAddressAndPort to; + public final InetAddressAndPort connectTo; // may be represented by a different IP address on this node's local network + public final EncryptionOptions encryption; + public final Framing framing; + public final Integer socketSendBufferSizeInBytes; + public final Integer applicationSendQueueCapacityInBytes; + public final Integer applicationSendQueueReserveEndpointCapacityInBytes; + public final ResourceLimits.Limit applicationSendQueueReserveGlobalCapacityInBytes; + public final Boolean tcpNoDelay; + public final int flushLowWaterMark, flushHighWaterMark; + public final Integer tcpConnectTimeoutInMS; + public final Integer tcpUserTimeoutInMS; + public final AcceptVersions acceptVersions; + public final InetAddressAndPort from; + public final SocketFactory socketFactory; + public final OutboundMessageCallbacks callbacks; + public final OutboundDebugCallbacks debug; + public final EndpointMessagingVersions endpointToVersion; + + public OutboundConnectionSettings(InetAddressAndPort to) + { + this(to, null); + } + + public 
OutboundConnectionSettings(InetAddressAndPort to, InetAddressAndPort preferred) + { + this(null, to, preferred, null, null, null, null, null, null, null, 1 << 15, 1 << 16, null, null, null, null, null, null, null, null); + } + + private OutboundConnectionSettings(IInternodeAuthenticator authenticator, + InetAddressAndPort to, + InetAddressAndPort connectTo, + EncryptionOptions encryption, + Framing framing, + Integer socketSendBufferSizeInBytes, + Integer applicationSendQueueCapacityInBytes, + Integer applicationSendQueueReserveEndpointCapacityInBytes, + ResourceLimits.Limit applicationSendQueueReserveGlobalCapacityInBytes, + Boolean tcpNoDelay, + int flushLowWaterMark, + int flushHighWaterMark, + Integer tcpConnectTimeoutInMS, + Integer tcpUserTimeoutInMS, + AcceptVersions acceptVersions, + InetAddressAndPort from, + SocketFactory socketFactory, + OutboundMessageCallbacks callbacks, + OutboundDebugCallbacks debug, + EndpointMessagingVersions endpointToVersion) + { + Preconditions.checkArgument(socketSendBufferSizeInBytes == null || socketSendBufferSizeInBytes == 0 || socketSendBufferSizeInBytes >= 1 << 10, "illegal socket send buffer size: " + socketSendBufferSizeInBytes); + Preconditions.checkArgument(applicationSendQueueCapacityInBytes == null || applicationSendQueueCapacityInBytes >= 1 << 10, "illegal application send queue capacity: " + applicationSendQueueCapacityInBytes); + Preconditions.checkArgument(tcpUserTimeoutInMS == null || tcpUserTimeoutInMS >= 0, "tcp user timeout must be non negative: " + tcpUserTimeoutInMS); + Preconditions.checkArgument(tcpConnectTimeoutInMS == null || tcpConnectTimeoutInMS > 0, "tcp connect timeout must be positive: " + tcpConnectTimeoutInMS); + + this.authenticator = authenticator; + this.to = to; + this.connectTo = connectTo; + this.encryption = encryption; + this.framing = framing; + this.socketSendBufferSizeInBytes = socketSendBufferSizeInBytes; + this.applicationSendQueueCapacityInBytes = 
applicationSendQueueCapacityInBytes; + this.applicationSendQueueReserveEndpointCapacityInBytes = applicationSendQueueReserveEndpointCapacityInBytes; + this.applicationSendQueueReserveGlobalCapacityInBytes = applicationSendQueueReserveGlobalCapacityInBytes; + this.tcpNoDelay = tcpNoDelay; + this.flushLowWaterMark = flushLowWaterMark; + this.flushHighWaterMark = flushHighWaterMark; + this.tcpConnectTimeoutInMS = tcpConnectTimeoutInMS; + this.tcpUserTimeoutInMS = tcpUserTimeoutInMS; + this.acceptVersions = acceptVersions; + this.from = from; + this.socketFactory = socketFactory; + this.callbacks = callbacks; + this.debug = debug; + this.endpointToVersion = endpointToVersion; + } + + public boolean authenticate() + { + return authenticator.authenticate(to.address, to.port); + } + + public boolean withEncryption() + { + return encryption != null; + } + + public String toString() + { + return String.format("peer: (%s, %s), framing: %s, encryption: %s", + to, connectTo, framing, encryptionLogStatement(encryption)); + } + + public OutboundConnectionSettings withAuthenticator(IInternodeAuthenticator authenticator) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + @SuppressWarnings("unused") + public OutboundConnectionSettings toEndpoint(InetAddressAndPort endpoint) + { + return new OutboundConnectionSettings(authenticator, endpoint, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, 
flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withConnectTo(InetAddressAndPort connectTo) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withEncryption(ServerEncryptionOptions encryption) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + @SuppressWarnings("unused") + public OutboundConnectionSettings withFraming(Framing framing) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withSocketSendBufferSizeInBytes(int socketSendBufferSizeInBytes) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, 
applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + @SuppressWarnings("unused") + public OutboundConnectionSettings withApplicationSendQueueCapacityInBytes(int applicationSendQueueCapacityInBytes) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withApplicationReserveSendQueueCapacityInBytes(Integer applicationReserveSendQueueEndpointCapacityInBytes, ResourceLimits.Limit applicationReserveSendQueueGlobalCapacityInBytes) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationReserveSendQueueEndpointCapacityInBytes, applicationReserveSendQueueGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + @SuppressWarnings("unused") + public OutboundConnectionSettings withTcpNoDelay(boolean tcpNoDelay) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, 
tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + @SuppressWarnings("unused") + public OutboundConnectionSettings withNettyBufferBounds(WriteBufferWaterMark nettyBufferBounds) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withTcpConnectTimeoutInMS(int tcpConnectTimeoutInMS) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withTcpUserTimeoutInMS(int tcpUserTimeoutInMS) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withAcceptVersions(AcceptVersions acceptVersions) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, 
applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withFrom(InetAddressAndPort from) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withSocketFactory(SocketFactory socketFactory) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withCallbacks(OutboundMessageCallbacks callbacks) + { + return new OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withDebugCallbacks(OutboundDebugCallbacks debug) + { + return new 
OutboundConnectionSettings(authenticator, to, connectTo, encryption, framing, + socketSendBufferSizeInBytes, applicationSendQueueCapacityInBytes, + applicationSendQueueReserveEndpointCapacityInBytes, applicationSendQueueReserveGlobalCapacityInBytes, + tcpNoDelay, flushLowWaterMark, flushHighWaterMark, tcpConnectTimeoutInMS, + tcpUserTimeoutInMS, acceptVersions, from, socketFactory, callbacks, debug, endpointToVersion); + } + + public OutboundConnectionSettings withDefaultReserveLimits() + { + Integer applicationReserveSendQueueEndpointCapacityInBytes = this.applicationSendQueueReserveEndpointCapacityInBytes; + ResourceLimits.Limit applicationReserveSendQueueGlobalCapacityInBytes = this.applicationSendQueueReserveGlobalCapacityInBytes; + + if (applicationReserveSendQueueEndpointCapacityInBytes == null) + applicationReserveSendQueueEndpointCapacityInBytes = DatabaseDescriptor.getInternodeApplicationSendQueueReserveEndpointCapacityInBytes(); + if (applicationReserveSendQueueGlobalCapacityInBytes == null) + applicationReserveSendQueueGlobalCapacityInBytes = MessagingService.instance().outboundGlobalReserveLimit; + + return withApplicationReserveSendQueueCapacityInBytes(applicationReserveSendQueueEndpointCapacityInBytes, applicationReserveSendQueueGlobalCapacityInBytes); + } + + public IInternodeAuthenticator authenticator() + { + return authenticator != null ? authenticator : DatabaseDescriptor.getInternodeAuthenticator(); + } + + public EndpointMessagingVersions endpointToVersion() + { + if (endpointToVersion == null) + return instance().versions; + return endpointToVersion; + } + + public InetAddressAndPort from() + { + return from != null ? from : FBUtilities.getBroadcastAddressAndPort(); + } + + public OutboundDebugCallbacks debug() + { + return debug != null ? debug : OutboundDebugCallbacks.NONE; + } + + public EncryptionOptions encryption() + { + return encryption != null ? 
encryption : defaultEncryptionOptions(to); + } + + public SocketFactory socketFactory() + { + return socketFactory != null ? socketFactory : instance().socketFactory; + } + + public OutboundMessageCallbacks callbacks() + { + return callbacks != null ? callbacks : instance().callbacks; + } + + public int socketSendBufferSizeInBytes() + { + return socketSendBufferSizeInBytes != null ? socketSendBufferSizeInBytes + : DatabaseDescriptor.getInternodeSocketSendBufferSizeInBytes(); + } + + public int applicationSendQueueCapacityInBytes() + { + return applicationSendQueueCapacityInBytes != null ? applicationSendQueueCapacityInBytes + : DatabaseDescriptor.getInternodeApplicationSendQueueCapacityInBytes(); + } + + public ResourceLimits.Limit applicationSendQueueReserveGlobalCapacityInBytes() + { + return applicationSendQueueReserveGlobalCapacityInBytes != null ? applicationSendQueueReserveGlobalCapacityInBytes + : instance().outboundGlobalReserveLimit; + } + + public int applicationSendQueueReserveEndpointCapacityInBytes() + { + return applicationSendQueueReserveEndpointCapacityInBytes != null ? applicationSendQueueReserveEndpointCapacityInBytes + : DatabaseDescriptor.getInternodeApplicationReceiveQueueReserveEndpointCapacityInBytes(); + } + + public int tcpConnectTimeoutInMS() + { + return tcpConnectTimeoutInMS != null ? tcpConnectTimeoutInMS + : DatabaseDescriptor.getInternodeTcpConnectTimeoutInMS(); + } + + public int tcpUserTimeoutInMS() + { + return tcpUserTimeoutInMS != null ? tcpUserTimeoutInMS + : DatabaseDescriptor.getInternodeTcpUserTimeoutInMS(); + } + + public boolean tcpNoDelay() + { + if (tcpNoDelay != null) + return tcpNoDelay; + + if (isInLocalDC(getEndpointSnitch(), getBroadcastAddressAndPort(), to)) + return INTRADC_TCP_NODELAY; + + return DatabaseDescriptor.getInterDCTcpNoDelay(); + } + + public AcceptVersions acceptVersions(ConnectionCategory category) + { + return acceptVersions != null ? acceptVersions + : category.isStreaming() + ? 
MessagingService.accept_streaming + : MessagingService.accept_messaging; + } + + public OutboundConnectionSettings withLegacyPortIfNecessary(int messagingVersion) + { + return withConnectTo(maybeWithSecurePort(connectTo(), messagingVersion, withEncryption())); + } + + public InetAddressAndPort connectTo() + { + InetAddressAndPort connectTo = this.connectTo; + if (connectTo == null) + connectTo = SystemKeyspace.getPreferredIP(to); + return connectTo; + } + + public Framing framing(ConnectionCategory category) + { + if (framing != null) + return framing; + + if (category.isStreaming()) + return Framing.UNPROTECTED; + + return shouldCompressConnection(getEndpointSnitch(), getBroadcastAddressAndPort(), to) + ? Framing.LZ4 : Framing.CRC; + } + + // note that connectTo is updated even if specified, in the case of pre40 messaging and using encryption (to update port) + public OutboundConnectionSettings withDefaults(ConnectionCategory category) + { + if (to == null) + throw new IllegalArgumentException(); + + return new OutboundConnectionSettings(authenticator(), to, connectTo(), + encryption(), framing(category), + socketSendBufferSizeInBytes(), applicationSendQueueCapacityInBytes(), + applicationSendQueueReserveEndpointCapacityInBytes(), + applicationSendQueueReserveGlobalCapacityInBytes(), + tcpNoDelay(), flushLowWaterMark, flushHighWaterMark, + tcpConnectTimeoutInMS(), tcpUserTimeoutInMS(), acceptVersions(category), + from(), socketFactory(), callbacks(), debug(), endpointToVersion()); + } + + private static boolean isInLocalDC(IEndpointSnitch snitch, InetAddressAndPort localHost, InetAddressAndPort remoteHost) + { + String remoteDC = snitch.getDatacenter(remoteHost); + String localDC = snitch.getDatacenter(localHost); + return remoteDC != null && remoteDC.equals(localDC); + } + + @VisibleForTesting + static EncryptionOptions defaultEncryptionOptions(InetAddressAndPort endpoint) + { + ServerEncryptionOptions options = 
DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); + return options.shouldEncrypt(endpoint) ? options : null; + } + + @VisibleForTesting + static boolean shouldCompressConnection(IEndpointSnitch snitch, InetAddressAndPort localHost, InetAddressAndPort remoteHost) + { + return (DatabaseDescriptor.internodeCompression() == Config.InternodeCompression.all) + || ((DatabaseDescriptor.internodeCompression() == Config.InternodeCompression.dc) && !isInLocalDC(snitch, localHost, remoteHost)); + } + + private static InetAddressAndPort maybeWithSecurePort(InetAddressAndPort address, int messagingVersion, boolean isEncrypted) + { + if (!isEncrypted || messagingVersion >= VERSION_40) + return address; + + // if we don't know the version of the peer, assume it is 4.0 (or higher) as the only time it would be lower + // (as in a 3.x version) is during a cluster upgrade (from 3.x to 4.0). In that case the outbound connection will + // unfortunately fail - however the peer should connect to this node (at some point), and once we learn its version, it'll be + // in versions map. thus, when we attempt to reconnect to that node, we'll have the version and we can get the correct port. + // we will be able to remove this logic at 5.0. + // Also as of 4.0 we will propagate the "regular" port (which will support both SSL and non-SSL) via gossip so + // for SSL and version 4.0 always connect to the gossiped port because if SSL is enabled it should ALWAYS + // listen for SSL on the "regular" port. + return address.withPort(DatabaseDescriptor.getSSLStoragePort()); + } + +} diff --git a/src/java/org/apache/cassandra/net/OutboundConnections.java b/src/java/org/apache/cassandra/net/OutboundConnections.java new file mode 100644 index 000000000000..5f9190b5b671 --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundConnections.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.channels.ClosedChannelException; +import java.util.List; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import com.carrotsearch.hppc.ObjectObjectOpenHashMap; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.config.Config; +import org.apache.cassandra.gms.Gossiper; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.metrics.InternodeOutboundMetrics; +import org.apache.cassandra.utils.concurrent.SimpleCondition; + +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.ConnectionType.URGENT_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; + +/** + * Groups a set of outbound connections to a given peer, and routes outgoing messages to the appropriate connection + * (based upon message's type or size). 
Contains a {@link OutboundConnection} for each of the + * {@link ConnectionType} types. + */ +public class OutboundConnections +{ + @VisibleForTesting + public static final int LARGE_MESSAGE_THRESHOLD = Integer.getInteger(Config.PROPERTY_PREFIX + "otcp_large_message_threshold", 1024 * 64) + - Math.max(Math.max(LegacyLZ4Constants.HEADER_LENGTH, FrameEncoderCrc.HEADER_AND_TRAILER_LENGTH), FrameEncoderLZ4.HEADER_AND_TRAILER_LENGTH); + + private final SimpleCondition metricsReady = new SimpleCondition(); + private volatile InternodeOutboundMetrics metrics; + private final BackPressureState backPressureState; + private final ResourceLimits.Limit reserveCapacity; + + private OutboundConnectionSettings template; + public final OutboundConnection small; + public final OutboundConnection large; + public final OutboundConnection urgent; + + private OutboundConnections(OutboundConnectionSettings template, BackPressureState backPressureState) + { + this.backPressureState = backPressureState; + this.template = template = template.withDefaultReserveLimits(); + reserveCapacity = new ResourceLimits.Concurrent(template.applicationSendQueueReserveEndpointCapacityInBytes); + ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(reserveCapacity, template.applicationSendQueueReserveGlobalCapacityInBytes); + this.small = new OutboundConnection(SMALL_MESSAGES, template, reserveCapacityInBytes); + this.large = new OutboundConnection(LARGE_MESSAGES, template, reserveCapacityInBytes); + this.urgent = new OutboundConnection(URGENT_MESSAGES, template, reserveCapacityInBytes); + } + + /** + * Select the appropriate connection for the provided message and use it to send the message. 
+ */ + public void enqueue(Message msg, ConnectionType type) throws ClosedChannelException + { + connectionFor(msg, type).enqueue(msg); + } + + static OutboundConnections tryRegister(ConcurrentMap in, K key, OutboundConnectionSettings settings, BackPressureState backPressureState) + { + OutboundConnections connections = in.get(key); + if (connections == null) + { + connections = new OutboundConnections(settings, backPressureState); + OutboundConnections existing = in.putIfAbsent(key, connections); + + if (existing == null) + { + connections.metrics = new InternodeOutboundMetrics(settings.to, connections); + connections.metricsReady.signalAll(); + } + else + { + connections.metricsReady.signalAll(); + connections.close(false); + connections = existing; + } + } + return connections; + } + + BackPressureState getBackPressureState() + { + return backPressureState; + } + + /** + * Reconnect to the peer using the given {@code addr}. Outstanding messages in each channel will be sent on the + * current channel. Typically this function is used for something like EC2 public IP addresses which need to be used + * for communication between EC2 regions. + * + * @param addr IP Address to use (and prefer) going forward for connecting to the peer + */ + synchronized Future reconnectWithNewIp(InetAddressAndPort addr) + { + template = template.withConnectTo(addr); + return new FutureCombiner( + apply(c -> c.reconnectWith(template)) + ); + } + + /** + * Close the connections permanently + * + * @param flushQueues {@code true} if existing messages in the queue should be sent before closing. 
+ */ + synchronized Future scheduleClose(long time, TimeUnit unit, boolean flushQueues) + { + // immediately release our metrics, so that if we need to re-open immediately we can safely register a new one + releaseMetrics(); + return new FutureCombiner( + apply(c -> c.scheduleClose(time, unit, flushQueues)) + ); + } + + /** + * Close the connections permanently + * + * @param flushQueues {@code true} if existing messages in the queue should be sent before closing. + */ + synchronized Future close(boolean flushQueues) + { + // immediately release our metrics, so that if we need to re-open immediately we can safely register a new one + releaseMetrics(); + return new FutureCombiner( + apply(c -> c.close(flushQueues)) + ); + } + + private void releaseMetrics() + { + try + { + metricsReady.await(); + } + catch (InterruptedException e) + { + throw new RuntimeException(e); + } + + if (metrics != null) + metrics.release(); + } + + /** + * Close each netty channel and its socket + */ + void interrupt() + { + // must return a non-null value for ImmutableList.of() + apply(OutboundConnection::interrupt); + } + + /** + * Apply the given function to each of the connections we are pooling, returning the results as a list + */ + private List apply(Function f) + { + return ImmutableList.of( + f.apply(urgent), f.apply(small), f.apply(large) + ); + } + + @VisibleForTesting + OutboundConnection connectionFor(Message message) + { + return connectionFor(message, null); + } + + private OutboundConnection connectionFor(Message msg, ConnectionType forceConnection) + { + return connectionFor(connectionTypeFor(msg, forceConnection)); + } + + private static ConnectionType connectionTypeFor(Message msg, ConnectionType specifyConnection) + { + if (specifyConnection != null) + return specifyConnection; + + if (msg.verb().priority == Verb.Priority.P0) + return URGENT_MESSAGES; + + return msg.serializedSize(current_version) <= LARGE_MESSAGE_THRESHOLD + ? 
SMALL_MESSAGES + : LARGE_MESSAGES; + } + + @VisibleForTesting + final OutboundConnection connectionFor(ConnectionType type) + { + switch (type) + { + case SMALL_MESSAGES: + return small; + case LARGE_MESSAGES: + return large; + case URGENT_MESSAGES: + return urgent; + default: + throw new IllegalArgumentException("unsupported connection type: " + type); + } + } + + public long usingReserveBytes() + { + return reserveCapacity.using(); + } + + long expiredCallbacks() + { + return metrics.expiredCallbacks.getCount(); + } + + void incrementExpiredCallbackCount() + { + metrics.expiredCallbacks.mark(); + } + + OutboundConnectionSettings template() + { + return template; + } + + private static class UnusedConnectionMonitor + { + UnusedConnectionMonitor(MessagingService messagingService) + { + this.messagingService = messagingService; + } + + static class Counts + { + final long small, large, urgent; + Counts(long small, long large, long urgent) + { + this.small = small; + this.large = large; + this.urgent = urgent; + } + } + + final MessagingService messagingService; + ObjectObjectOpenHashMap prevEndpointToCounts = new ObjectObjectOpenHashMap<>(); + + private void closeUnusedSinceLastRun() + { + ObjectObjectOpenHashMap curEndpointToCounts = new ObjectObjectOpenHashMap<>(); + for (OutboundConnections connections : messagingService.channelManagers.values()) + { + Counts cur = new Counts( + connections.small.submittedCount(), + connections.large.submittedCount(), + connections.urgent.submittedCount() + ); + curEndpointToCounts.put(connections.template.to, cur); + + Counts prev = prevEndpointToCounts.get(connections.template.to); + if (prev == null) + continue; + + if (cur.small != prev.small && cur.large != prev.large && cur.urgent != prev.urgent) + continue; + + if (cur.small == prev.small && cur.large == prev.large && cur.urgent == prev.urgent + && !Gossiper.instance.isKnownEndpoint(connections.template.to)) + { + // close entirely if no traffic and the endpoint is unknown 
+ messagingService.closeOutboundNow(connections); + continue; + } + + if (cur.small == prev.small) + connections.small.interrupt(); + + if (cur.large == prev.large) + connections.large.interrupt(); + + if (cur.urgent == prev.urgent) + connections.urgent.interrupt(); + } + + prevEndpointToCounts = curEndpointToCounts; + } + } + + static void scheduleUnusedConnectionMonitoring(MessagingService messagingService, ScheduledExecutorService executor, long delay, TimeUnit units) + { + executor.scheduleWithFixedDelay(new UnusedConnectionMonitor(messagingService)::closeUnusedSinceLastRun, 0L, delay, units); + } + + @VisibleForTesting + static OutboundConnections unsafeCreate(OutboundConnectionSettings template, BackPressureState backPressureState) + { + OutboundConnections connections = new OutboundConnections(template, backPressureState); + connections.metricsReady.signalAll(); + return connections; + } + +} diff --git a/src/java/org/apache/cassandra/net/OutboundDebugCallbacks.java b/src/java/org/apache/cassandra/net/OutboundDebugCallbacks.java new file mode 100644 index 000000000000..3b83519fc3fc --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundDebugCallbacks.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +interface OutboundDebugCallbacks +{ + static final OutboundDebugCallbacks NONE = new OutboundDebugCallbacks() + { + public void onSendSmallFrame(int messageCount, int payloadSizeInBytes) {} + public void onSentSmallFrame(int messageCount, int payloadSizeInBytes) {} + public void onFailedSmallFrame(int messageCount, int payloadSizeInBytes) {} + public void onConnect(int messagingVersion, OutboundConnectionSettings settings) {} + }; + + /** A complete Frame has been handed to Netty to write to the wire. */ + void onSendSmallFrame(int messageCount, int payloadSizeInBytes); + + /** A complete Frame has been serialized to the wire */ + void onSentSmallFrame(int messageCount, int payloadSizeInBytes); + + /** Failed to send an entire frame due to network problems; presumed to be invoked in same order as onSendSmallFrame */ + void onFailedSmallFrame(int messageCount, int payloadSizeInBytes); + + void onConnect(int messagingVersion, OutboundConnectionSettings settings); +} diff --git a/src/java/org/apache/cassandra/net/ForwardToContainer.java b/src/java/org/apache/cassandra/net/OutboundMessageCallbacks.java similarity index 55% rename from src/java/org/apache/cassandra/net/ForwardToContainer.java rename to src/java/org/apache/cassandra/net/OutboundMessageCallbacks.java index b22eed6b9f91..abf3f4117d0e 100644 --- a/src/java/org/apache/cassandra/net/ForwardToContainer.java +++ b/src/java/org/apache/cassandra/net/OutboundMessageCallbacks.java @@ -15,30 +15,21 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.cassandra.net; -import java.io.Serializable; -import java.util.Collection; - -import com.google.common.base.Preconditions; - import org.apache.cassandra.locator.InetAddressAndPort; -/** - * Contains forward to information until it can be serialized as part of a message using a version - * specific serialization - */ -public class ForwardToContainer implements Serializable +interface OutboundMessageCallbacks { - public final Collection targets; - public final int[] messageIds; + /** A message was not enqueued to the link because too many messages are already waiting to send */ + void onOverloaded(Message message, InetAddressAndPort peer); + + /** A message was not serialized to a frame because it had expired */ + void onExpired(Message message, InetAddressAndPort peer); + + /** A message was not fully or successfully serialized to a frame because an exception was thrown */ + void onFailedSerialize(Message message, InetAddressAndPort peer, int messagingVersion, int bytesWrittenToNetwork, Throwable failure); - public ForwardToContainer(Collection targets, - int[] messageIds) - { - Preconditions.checkArgument(targets.size() == messageIds.length); - this.targets = targets; - this.messageIds = messageIds; - } + /** A message was not sent because the connection was forcibly closed */ + void onDiscardOnClose(Message message, InetAddressAndPort peer); } diff --git a/src/java/org/apache/cassandra/net/OutboundMessageQueue.java b/src/java/org/apache/cassandra/net/OutboundMessageQueue.java new file mode 100644 index 000000000000..48c766629d4d --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundMessageQueue.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; + +import com.google.common.util.concurrent.Uninterruptibles; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.lang.Math.min; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +/** + * A composite queue holding messages to be delivered by an {@link OutboundConnection}. + * + * Contains two queues: + * 1. An external MPSC {@link ManyToOneConcurrentLinkedQueue} for producers to enqueue messages onto + * 2. An internal intermediate {@link PrunableArrayQueue} into which the external queue is + * drained with exclusive access and from which actual deliveries happen + * The second, intermediate queue exists to enable efficient in-place pruning of expired messages. + * + * Said pruning will be attempted in several scenarios: + * 1. By callers invoking {@link #add(Message)} - if metadata indicates presence of expired messages + * in the queue, and if exclusive access can be immediately obtained (non-blockingly) + * 2. 
By {@link OutboundConnection}, periodically, while disconnected + * 3. As an optimisation, in an attempt to free up endpoint capacity on {@link OutboundConnection#enqueue(Message)} + * if current endpoint reserve was insufficient + */ +class OutboundMessageQueue +{ + private static final Logger logger = LoggerFactory.getLogger(OutboundMessageQueue.class); + + interface MessageConsumer + { + boolean accept(Message message) throws Produces; + } + + private final MessageConsumer onExpired; + + private final ManyToOneConcurrentLinkedQueue> externalQueue = new ManyToOneConcurrentLinkedQueue<>(); + private final PrunableArrayQueue> internalQueue = new PrunableArrayQueue<>(256); + + private volatile long earliestExpiresAt = Long.MAX_VALUE; + private static final AtomicLongFieldUpdater earliestExpiresAtUpdater = + AtomicLongFieldUpdater.newUpdater(OutboundMessageQueue.class, "earliestExpiresAt"); + + OutboundMessageQueue(MessageConsumer onExpired) + { + this.onExpired = onExpired; + } + + /** + * Add the provided message to the queue. Always succeeds. + */ + void add(Message m) + { + maybePruneExpired(); + externalQueue.offer(m); + maybeUpdateMinimumExpiryTime(m.expiresAtNanos()); + } + + /** + * Try to obtain the lock; if this fails, a callback will be registered to be invoked when + * the lock is relinquished. + * + * This callback will run WITHOUT ownership of the lock, so must re-obtain the lock. + * + * @return null if failed to obtain the lock + */ + WithLock lockOrCallback(long nowNanos, Runnable callbackIfDeferred) + { + if (!lockOrCallback(callbackIfDeferred)) + return null; + + return new WithLock(nowNanos); + } + + /** + * Try to obtain the lock. If successful, invoke the provided consumer immediately, otherwise + * register it to be invoked when the lock is relinquished. 
+ */ + void runEventually(Consumer runEventually) + { + try (WithLock withLock = lockOrCallback(approxTime.now(), () -> runEventually(runEventually))) + { + if (withLock != null) + runEventually.accept(withLock); + } + } + + /** + * If succeeds to obtain the lock, polls the queue, otherwise registers the provided callback + * to be invoked when the lock is relinquished. + * + * May return null when the queue is non-empty - if the lock could not be acquired. + */ + Message tryPoll(long nowNanos, Runnable elseIfDeferred) + { + try (WithLock withLock = lockOrCallback(nowNanos, elseIfDeferred)) + { + if (withLock == null) + return null; + + return withLock.poll(); + } + } + + class WithLock implements AutoCloseable + { + private final long nowNanos; + + private WithLock(long nowNanos) + { + this.nowNanos = nowNanos; + earliestExpiresAt = Long.MAX_VALUE; + externalQueue.drain(internalQueue::offer); + } + + Message poll() + { + Message m; + while (null != (m = internalQueue.poll())) + { + if (shouldSend(m, nowNanos)) + break; + + onExpired.accept(m); + } + + return m; + } + + void removeHead(Message expectHead) + { + assert expectHead == internalQueue.peek(); + internalQueue.poll(); + } + + Message peek() + { + Message m; + while (null != (m = internalQueue.peek())) + { + if (shouldSend(m, nowNanos)) + break; + + internalQueue.poll(); + onExpired.accept(m); + } + + return m; + } + + void consume(Consumer> consumer) + { + Message m; + while (null != (m = poll())) + consumer.accept(m); + } + + @Override + public void close() + { + pruneInternalQueueWithLock(nowNanos); + unlock(); + } + } + + /** + * Call periodically if cannot expect to promptly invoke consume() + */ + boolean maybePruneExpired() + { + return maybePruneExpired(approxTime.now()); + } + + private boolean maybePruneExpired(long nowNanos) + { + if (approxTime.isAfter(nowNanos, earliestExpiresAt)) + return tryRun(() -> pruneWithLock(nowNanos)); + return false; + } + + private void 
maybeUpdateMinimumExpiryTime(long newTime) + { + if (newTime < earliestExpiresAt) + earliestExpiresAtUpdater.accumulateAndGet(this, newTime, Math::min); + } + + /* + * Drain external queue into the internal one and prune the latter in-place. + */ + private void pruneWithLock(long nowNanos) + { + earliestExpiresAt = Long.MAX_VALUE; + externalQueue.drain(internalQueue::offer); + pruneInternalQueueWithLock(nowNanos); + } + + /* + * Prune the internal queue in-place. + */ + private void pruneInternalQueueWithLock(long nowNanos) + { + class Pruner implements PrunableArrayQueue.Pruner> + { + private long earliestExpiresAt = Long.MAX_VALUE; + + public boolean shouldPrune(Message message) + { + return !shouldSend(message, nowNanos); + } + + public void onPruned(Message message) + { + onExpired.accept(message); + } + + public void onKept(Message message) + { + earliestExpiresAt = min(message.expiresAtNanos(), earliestExpiresAt); + } + } + + Pruner pruner = new Pruner(); + internalQueue.prune(pruner); + + maybeUpdateMinimumExpiryTime(pruner.earliestExpiresAt); + } + + private static class Locked implements Runnable + { + final Runnable run; + final Locked next; + private Locked(Runnable run, Locked next) + { + this.run = run; + this.next = next; + } + + Locked andThen(Runnable next) + { + return new Locked(next, this); + } + + public void run() + { + Locked cur = this; + while (cur != null) + { + try + { + cur.run.run(); + } + catch (Throwable t) + { + logger.error("Unexpected error when executing deferred lock-intending functions", t); + } + cur = cur.next; + } + } + } + + private static final Locked LOCKED = new Locked(() -> {}, null); + + private volatile Locked locked = null; + private static final AtomicReferenceFieldUpdater lockedUpdater = + AtomicReferenceFieldUpdater.newUpdater(OutboundMessageQueue.class, Locked.class, "locked"); + + /** + * Run runEventually either immediately in the calling thread if we can obtain the lock, or ask the lock's current + owner 
to attempt to run it when the lock is released. This may be passed between a sequence of owners, as the present + owner releases the lock before trying to acquire it again and execute the task. + */ + private void runEventually(Runnable runEventually) + { + if (!lockOrCallback(() -> runEventually(runEventually))) + return; + + try + { + runEventually.run(); + } + finally + { + unlock(); + } + } + + /** + * If we can immediately obtain the lock, execute runIfLocked and return true; + * otherwise do nothing and return false. + */ + private boolean tryRun(Runnable runIfAvailable) + { + if (!tryLock()) + return false; + + try + { + runIfAvailable.run(); + return true; + } + finally + { + unlock(); + } + } + + /** + * @return true iff the caller now owns the lock + */ + private boolean tryLock() + { + return locked == null && lockedUpdater.compareAndSet(this, null, LOCKED); + } + + /** + * Try to obtain the lock; if this fails, a callback will be registered to be invoked when the lock is relinquished. + * This callback will run WITHOUT ownership of the lock, so must re-obtain the lock. + * + * @return true iff the caller now owns the lock + */ + private boolean lockOrCallback(Runnable callbackWhenAvailable) + { + if (callbackWhenAvailable == null) + return tryLock(); + + while (true) + { + Locked current = locked; + if (current == null && lockedUpdater.compareAndSet(this, null, LOCKED)) + return true; + else if (current != null && lockedUpdater.compareAndSet(this, current, current.andThen(callbackWhenAvailable))) + return false; + } + } + + private void unlock() + { + Locked locked = lockedUpdater.getAndSet(this, null); + locked.run(); + } + + + /** + * While removal happens extremely infrequently, it seems possible for many to still interleave with a connection + * being closed, as experimentally we have encountered enough pending removes to overflow the Locked call stack + * (prior to making its evaluation iterative). 
+ * + * While the stack can no longer be exhausted, this suggests a high potential cost for evaluating all removals, + * so to ensure system stability we aggregate all pending removes into a single shared object that evaluate + * together with only a single lock acquisition. + */ + private volatile RemoveRunner removeRunner = null; + private static final AtomicReferenceFieldUpdater removeRunnerUpdater = + AtomicReferenceFieldUpdater.newUpdater(OutboundMessageQueue.class, RemoveRunner.class, "removeRunner"); + + static class Remove + { + final Message message; + final Remove next; + + Remove(Message message, Remove next) + { + this.message = message; + this.next = next; + } + } + + private class RemoveRunner extends AtomicReference implements Runnable + { + final CountDownLatch done = new CountDownLatch(1); + final Set> removed = Collections.newSetFromMap(new IdentityHashMap<>()); + + RemoveRunner() { super(new Remove(null, null)); } + + boolean undo(Message message) + { + return null != updateAndGet(prev -> prev == null ? 
null : new Remove(message, prev)); + } + + public void run() + { + Set> remove = Collections.newSetFromMap(new IdentityHashMap<>()); + removeRunner = null; + Remove undo = getAndSet(null); + while (undo.message != null) + { + remove.add(undo.message); + undo = undo.next; + } + + class Remover implements PrunableArrayQueue.Pruner> + { + private long earliestExpiresAt = Long.MAX_VALUE; + + @Override + public boolean shouldPrune(Message message) + { + return remove.contains(message); + } + + @Override + public void onPruned(Message message) + { + removed.add(message); + } + + @Override + public void onKept(Message message) + { + earliestExpiresAt = min(message.expiresAtNanos(), earliestExpiresAt); + } + } + + Remover remover = new Remover(); + earliestExpiresAt = Long.MAX_VALUE; + externalQueue.drain(internalQueue::offer); + internalQueue.prune(remover); + maybeUpdateMinimumExpiryTime(remover.earliestExpiresAt); + done.countDown(); + } + } + + /** + * Remove the provided Message from the queue, if present. + * + * WARNING: This is a blocking call. 
+ */ + boolean remove(Message remove) + { + if (remove == null) + throw new NullPointerException(); + + RemoveRunner runner; + while (true) + { + runner = removeRunner; + if (runner != null && runner.undo(remove)) + break; + + if (runner == null && removeRunnerUpdater.compareAndSet(this, null, runner = new RemoveRunner())) + { + runner.undo(remove); + runEventually(runner); + break; + } + } + + //noinspection UnstableApiUsage + Uninterruptibles.awaitUninterruptibly(runner.done); + return runner.removed.contains(remove); + } + + private static boolean shouldSend(Message m, long nowNanos) + { + return !approxTime.isAfter(nowNanos, m.expiresAtNanos()); + } +} diff --git a/src/java/org/apache/cassandra/net/OutboundSink.java b/src/java/org/apache/cassandra/net/OutboundSink.java new file mode 100644 index 000000000000..d19b3e2107a2 --- /dev/null +++ b/src/java/org/apache/cassandra/net/OutboundSink.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiPredicate; + +import org.apache.cassandra.locator.InetAddressAndPort; + +/** + * A message sink that all outbound messages go through. + * + * Default sink {@link Sink} used by {@link MessagingService} is MessagingService#doSend(), which proceeds to + * send messages over the network, but it can be overridden to filter out certain messages, record the fact + * of attempted delivery, or delay they delivery. + * + * This facility is most useful for test code. + */ +public class OutboundSink +{ + public interface Sink + { + void accept(Message message, InetAddressAndPort to, ConnectionType connectionType); + } + + private static class Filtered implements Sink + { + final BiPredicate, InetAddressAndPort> condition; + final Sink next; + + private Filtered(BiPredicate, InetAddressAndPort> condition, Sink next) + { + this.condition = condition; + this.next = next; + } + + public void accept(Message message, InetAddressAndPort to, ConnectionType connectionType) + { + if (condition.test(message, to)) + next.accept(message, to, connectionType); + } + } + + private volatile Sink sink; + private static final AtomicReferenceFieldUpdater sinkUpdater + = AtomicReferenceFieldUpdater.newUpdater(OutboundSink.class, Sink.class, "sink"); + + OutboundSink(Sink sink) + { + this.sink = sink; + } + + public void accept(Message message, InetAddressAndPort to, ConnectionType connectionType) + { + sink.accept(message, to, connectionType); + } + + public void add(BiPredicate, InetAddressAndPort> allow) + { + sinkUpdater.updateAndGet(this, sink -> new Filtered(allow, sink)); + } + + public void remove(BiPredicate, InetAddressAndPort> allow) + { + sinkUpdater.updateAndGet(this, sink -> without(sink, allow)); + } + + public void clear() + { + sinkUpdater.updateAndGet(this, OutboundSink::clear); + } + + private static Sink clear(Sink sink) + { + while (sink 
instanceof OutboundSink.Filtered) + sink = ((OutboundSink.Filtered) sink).next; + return sink; + } + + private static Sink without(Sink sink, BiPredicate, InetAddressAndPort> condition) + { + if (!(sink instanceof Filtered)) + return sink; + + Filtered filtered = (Filtered) sink; + Sink next = without(filtered.next, condition); + return condition.equals(filtered.condition) ? next + : next == filtered.next + ? sink + : new Filtered(filtered.condition, next); + } + +} diff --git a/src/java/org/apache/cassandra/net/ParamType.java b/src/java/org/apache/cassandra/net/ParamType.java new file mode 100644 index 000000000000..65723487e5da --- /dev/null +++ b/src/java/org/apache/cassandra/net/ParamType.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nullable; + +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.UUIDSerializer; + +import static java.lang.Math.max; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + +/** + * Type names and serializers for various parameters that can be put in {@link Message} params map. + * + * It should be safe to add new params without bumping messaging version - {@link Message} serializer + * will skip over any params it doesn't recognise. + * + * Please don't add boolean params here. Extend and use {@link MessageFlag} instead. + */ +public enum ParamType +{ + FORWARD_TO (0, "FORWARD_TO", ForwardingInfo.serializer), + RESPOND_TO (1, "FORWARD_FROM", inetAddressAndPortSerializer), + + @Deprecated + FAILURE_RESPONSE (2, "FAIL", LegacyFlag.serializer), + @Deprecated + FAILURE_REASON (3, "FAIL_REASON", RequestFailureReason.serializer), + @Deprecated + FAILURE_CALLBACK (4, "CAL_BAC", LegacyFlag.serializer), + + TRACE_SESSION (5, "TraceSession", UUIDSerializer.serializer), + TRACE_TYPE (6, "TraceType", Tracing.traceTypeSerializer), + + @Deprecated + TRACK_REPAIRED_DATA (7, "TrackRepaired", LegacyFlag.serializer); + + final int id; + @Deprecated final String legacyAlias; // pre-4.0 we used to serialize entire param name string + final IVersionedSerializer serializer; + + ParamType(int id, String legacyAlias, IVersionedSerializer serializer) + { + if (id < 0) + throw new IllegalArgumentException("ParamType id must be non-negative"); + + this.id = id; + this.legacyAlias = legacyAlias; + this.serializer = serializer; + } + + private static final ParamType[] idToTypeMap; + private static final Map aliasToTypeMap; + + 
static + { + ParamType[] types = values(); + + int max = -1; + for (ParamType t : types) + max = max(t.id, max); + + ParamType[] idMap = new ParamType[max + 1]; + Map aliasMap = new HashMap<>(); + + for (ParamType type : types) + { + if (idMap[type.id] != null) + throw new RuntimeException("Two ParamType-s that map to the same id: " + type.id); + idMap[type.id] = type; + + if (aliasMap.put(type.legacyAlias, type) != null) + throw new RuntimeException("Two ParamType-s that map to the same legacy alias: " + type.legacyAlias); + } + + idToTypeMap = idMap; + aliasToTypeMap = aliasMap; + } + + @Nullable + static ParamType lookUpById(int id) + { + if (id < 0) + throw new IllegalArgumentException("ParamType id must be non-negative (got " + id + ')'); + + return id < idToTypeMap.length ? idToTypeMap[id] : null; + } + + @Nullable + static ParamType lookUpByAlias(String alias) + { + return aliasToTypeMap.get(alias); + } +} diff --git a/src/java/org/apache/cassandra/net/ParameterType.java b/src/java/org/apache/cassandra/net/ParameterType.java deleted file mode 100644 index b7a88a8dd93c..000000000000 --- a/src/java/org/apache/cassandra/net/ParameterType.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net; - -import java.util.Map; - -import com.google.common.collect.ImmutableMap; - -import org.apache.cassandra.io.DummyByteVersionedSerializer; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.ShortVersionedSerializer; -import org.apache.cassandra.tracing.Tracing; -import org.apache.cassandra.utils.UUIDSerializer; - -/** - * Type names and serializers for various parameters that - */ -public enum ParameterType -{ - FORWARD_TO("FORWARD_TO", ForwardToSerializer.instance), - FORWARD_FROM("FORWARD_FROM", CompactEndpointSerializationHelper.instance), - FAILURE_RESPONSE("FAIL", DummyByteVersionedSerializer.instance), - FAILURE_REASON("FAIL_REASON", ShortVersionedSerializer.instance), - FAILURE_CALLBACK("CAL_BAC", DummyByteVersionedSerializer.instance), - TRACE_SESSION("TraceSession", UUIDSerializer.serializer), - TRACE_TYPE("TraceType", Tracing.traceTypeSerializer), - TRACK_REPAIRED_DATA("TrackRepaired", DummyByteVersionedSerializer.instance); - - public static final Map byName; - public final String key; - public final IVersionedSerializer serializer; - - static - { - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (ParameterType type : values()) - { - builder.put(type.key, type); - } - byName = builder.build(); - } - - ParameterType(String key, IVersionedSerializer serializer) - { - this.key = key; - this.serializer = serializer; - } - - public String key() - { - return key; - } - -} diff --git a/src/java/org/apache/cassandra/net/PingMessage.java b/src/java/org/apache/cassandra/net/PingMessage.java deleted file mode 100644 index 4a19f22b112d..000000000000 --- a/src/java/org/apache/cassandra/net/PingMessage.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net; - -import java.io.IOException; - -import org.apache.cassandra.hints.HintResponse; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; - -/** - * Conceptually the same as {@link org.apache.cassandra.gms.EchoMessage}, but indicates to the recipient which - * {@link ConnectionType} should be used for the response. 
- */ -public class PingMessage -{ - public static IVersionedSerializer serializer = new PingMessageSerializer(); - - public static final PingMessage smallChannelMessage = new PingMessage(ConnectionType.SMALL_MESSAGE); - public static final PingMessage largeChannelMessage = new PingMessage(ConnectionType.LARGE_MESSAGE); - public static final PingMessage gossipChannelMessage = new PingMessage(ConnectionType.GOSSIP); - - public final ConnectionType connectionType; - - public PingMessage(ConnectionType connectionType) - { - this.connectionType = connectionType; - } - - public static class PingMessageSerializer implements IVersionedSerializer - { - public void serialize(PingMessage t, DataOutputPlus out, int version) throws IOException - { - out.writeByte(t.connectionType.getId()); - } - - public PingMessage deserialize(DataInputPlus in, int version) throws IOException - { - ConnectionType connectionType = ConnectionType.fromId(in.readByte()); - - // if we ever create a new connection type, then during a rolling upgrade, the old nodes won't know about - // the new connection type (as it won't recognize the id), so just default to the small message type. - if (connectionType == null) - connectionType = ConnectionType.SMALL_MESSAGE; - - switch (connectionType) - { - case LARGE_MESSAGE: - return largeChannelMessage; - case GOSSIP: - return gossipChannelMessage; - case SMALL_MESSAGE: - default: - return smallChannelMessage; - } - } - - public long serializedSize(PingMessage t, int version) - { - return 1; - } - } -} diff --git a/src/java/org/apache/cassandra/net/PingRequest.java b/src/java/org/apache/cassandra/net/PingRequest.java new file mode 100644 index 000000000000..c02bd8099d4c --- /dev/null +++ b/src/java/org/apache/cassandra/net/PingRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; + +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; + +import static org.apache.cassandra.net.ConnectionType.URGENT_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; + +/** + * Indicates to the recipient which {@link ConnectionType} should be used for the response. 
+ */ +class PingRequest +{ + static final PingRequest forUrgent = new PingRequest(URGENT_MESSAGES); + static final PingRequest forSmall = new PingRequest(SMALL_MESSAGES); + static final PingRequest forLarge = new PingRequest(LARGE_MESSAGES); + + final ConnectionType connectionType; + + private PingRequest(ConnectionType connectionType) + { + this.connectionType = connectionType; + } + + static IVersionedSerializer serializer = new IVersionedSerializer() + { + public void serialize(PingRequest t, DataOutputPlus out, int version) throws IOException + { + out.writeByte(t.connectionType.id); + } + + public PingRequest deserialize(DataInputPlus in, int version) throws IOException + { + ConnectionType type = ConnectionType.fromId(in.readByte()); + + switch (type) + { + case URGENT_MESSAGES: return forUrgent; + case SMALL_MESSAGES: return forSmall; + case LARGE_MESSAGES: return forLarge; + } + + throw new IllegalStateException(); + } + + public long serializedSize(PingRequest t, int version) + { + return 1; + } + }; +} diff --git a/src/java/org/apache/cassandra/net/PingVerbHandler.java b/src/java/org/apache/cassandra/net/PingVerbHandler.java index d959b919bd62..a70cddc293b5 100644 --- a/src/java/org/apache/cassandra/net/PingVerbHandler.java +++ b/src/java/org/apache/cassandra/net/PingVerbHandler.java @@ -15,17 +15,15 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.cassandra.net; -public class PingVerbHandler implements IVerbHandler +class PingVerbHandler implements IVerbHandler { + static final PingVerbHandler instance = new PingVerbHandler(); + @Override - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - MessageOut msg = new MessageOut<>(MessagingService.Verb.REQUEST_RESPONSE, PongMessage.instance, - PongMessage.serializer, - message.payload.connectionType); - MessagingService.instance().sendReply(msg, id, message.from); + MessagingService.instance().send(message.emptyResponse(), message.from(), message.payload.connectionType); } } diff --git a/src/java/org/apache/cassandra/net/PrunableArrayQueue.java b/src/java/org/apache/cassandra/net/PrunableArrayQueue.java new file mode 100644 index 000000000000..1fca43ca2b33 --- /dev/null +++ b/src/java/org/apache/cassandra/net/PrunableArrayQueue.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.function.Predicate; + +/** + * A growing array-based queue that allows efficient bulk in-place removal. 
+ * + * Can think of this queue as if it were an {@link java.util.ArrayDeque} with {@link #prune(Pruner)} method - an efficient + * way to prune the queue in-place that is more expressive, and faster than {@link java.util.ArrayDeque#removeIf(Predicate)}. + * + * The latter has to perform O(n*n) shifts, whereas {@link #prune(Pruner)} only needs O(n) shifts at worst. + */ +final class PrunableArrayQueue +{ + public interface Pruner + { + /** + * @return whether the element should be pruned + * if {@code true}, the element will be removed from the queue, and {@link #onPruned(Object)} will be invoked, + * if {@code false}, the element will be kept, and {@link #onKept(Object)} will be invoked. + */ + boolean shouldPrune(E e); + + void onPruned(E e); + void onKept(E e); + } + + private int capacity; + private E[] buffer; + + /* + * mask = capacity - 1; + * since capacity is a power of 2, value % capacity == value & (capacity - 1) == value & mask + */ + private int mask; + + private int head = 0; + private int tail = 0; + + @SuppressWarnings("unchecked") + PrunableArrayQueue(int requestedCapacity) + { + capacity = Math.max(8, findNextPositivePowerOfTwo(requestedCapacity)); + mask = capacity - 1; + buffer = (E[]) new Object[capacity]; + } + + @SuppressWarnings("UnusedReturnValue") + boolean offer(E e) + { + buffer[tail] = e; + if ((tail = (tail + 1) & mask) == head) + doubleCapacity(); + return true; + } + + E peek() + { + return buffer[head]; + } + + E poll() + { + E result = buffer[head]; + if (null == result) + return null; + + buffer[head] = null; + head = (head + 1) & mask; + + return result; + } + + int size() + { + return (tail - head) & mask; + } + + boolean isEmpty() + { + return head == tail; + } + + /** + * Prunes the queue using the specified {@link Pruner} + * + * @return count of removed elements. 
+ */ + int prune(Pruner pruner) + { + E e; + int removed = 0; + + try + { + int size = size(); + for (int i = 0; i < size; i++) + { + /* + * We start at the tail and work backwards to minimise the number of copies + * as we expect to primarily prune from the front. + */ + int k = (tail - 1 - i) & mask; + e = buffer[k]; + + if (pruner.shouldPrune(e)) + { + buffer[k] = null; + removed++; + pruner.onPruned(e); + } + else + { + if (removed > 0) + { + buffer[(k + removed) & mask] = e; + buffer[k] = null; + } + pruner.onKept(e); + } + } + } + finally + { + head = (head + removed) & mask; + } + + return removed; + } + + @SuppressWarnings("unchecked") + private void doubleCapacity() + { + assert head == tail; + + int newCapacity = capacity << 1; + E[] newBuffer = (E[]) new Object[newCapacity]; + + int headPortionLen = capacity - head; + System.arraycopy(buffer, head, newBuffer, 0, headPortionLen); + System.arraycopy(buffer, 0, newBuffer, headPortionLen, tail); + + head = 0; + tail = capacity; + + capacity = newCapacity; + mask = newCapacity - 1; + buffer = newBuffer; + } + + private static int findNextPositivePowerOfTwo(int value) + { + return 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/src/java/org/apache/cassandra/net/RateBasedBackPressure.java b/src/java/org/apache/cassandra/net/RateBasedBackPressure.java index 22cec655638d..02d8cce35772 100644 --- a/src/java/org/apache/cassandra/net/RateBasedBackPressure.java +++ b/src/java/org/apache/cassandra/net/RateBasedBackPressure.java @@ -39,6 +39,8 @@ import org.apache.cassandra.utils.TimeSource; import org.apache.cassandra.utils.concurrent.IntervalLock; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + /** * Back-pressure algorithm based on rate limiting according to the ratio between incoming and outgoing rates, computed * over a sliding time window with size equal to write RPC timeout. 
@@ -84,7 +86,7 @@ public static ParameterizedClass withDefaultParams() public RateBasedBackPressure(Map args) { - this(args, new SystemTimeSource(), DatabaseDescriptor.getWriteRpcTimeout()); + this(args, new SystemTimeSource(), DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS)); } @VisibleForTesting diff --git a/src/java/org/apache/cassandra/net/RateBasedBackPressureState.java b/src/java/org/apache/cassandra/net/RateBasedBackPressureState.java index 9df056e9454a..a15087493c0b 100644 --- a/src/java/org/apache/cassandra/net/RateBasedBackPressureState.java +++ b/src/java/org/apache/cassandra/net/RateBasedBackPressureState.java @@ -61,7 +61,7 @@ class RateBasedBackPressureState extends IntervalLock implements BackPressureSta } @Override - public void onMessageSent(MessageOut message) {} + public void onMessageSent(Message message) {} @Override public void onResponseReceived() diff --git a/src/java/org/apache/cassandra/net/IAsyncCallback.java b/src/java/org/apache/cassandra/net/RequestCallback.java similarity index 61% rename from src/java/org/apache/cassandra/net/IAsyncCallback.java rename to src/java/org/apache/cassandra/net/RequestCallback.java index ceaf0721963f..9ed3a4b296be 100644 --- a/src/java/org/apache/cassandra/net/IAsyncCallback.java +++ b/src/java/org/apache/cassandra/net/RequestCallback.java @@ -17,24 +17,45 @@ */ package org.apache.cassandra.net; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.locator.InetAddressAndPort; + /** - * implementors of IAsyncCallback need to make sure that any public methods - * are threadsafe with respect to response() being called from the message + * implementors of {@link RequestCallback} need to make sure that any public methods + * are threadsafe with respect to {@link #onResponse} being called from the message * service. In particular, if any shared state is referenced, making * response alone synchronized will not suffice. 
*/ -public interface IAsyncCallback +public interface RequestCallback { /** * @param msg response received. */ - void response(MessageIn msg); + void onResponse(Message msg); + + /** + * Called when there is an exception on the remote node or timeout happens + */ + default void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) + { + } + + /** + * @return true if the callback should be invoked on failure + */ + default boolean invokeOnFailure() + { + return false; + } /** * @return true if this callback is on the read path and its latency should be * given as input to the dynamic snitch. */ - boolean isLatencyForSnitch(); + default boolean trackLatencyForSnitch() + { + return false; + } default boolean supportsBackPressure() { diff --git a/src/java/org/apache/cassandra/net/RequestCallbacks.java b/src/java/org/apache/cassandra/net/RequestCallbacks.java new file mode 100644 index 000000000000..fd3a09600114 --- /dev/null +++ b/src/java/org/apache/cassandra/net/RequestCallbacks.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeoutException; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.concurrent.DebuggableScheduledThreadPoolExecutor; +import org.apache.cassandra.concurrent.StageManager; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.ConsistencyLevel; +import org.apache.cassandra.db.Mutation; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.locator.Replica; +import org.apache.cassandra.metrics.InternodeOutboundMetrics; +import org.apache.cassandra.service.AbstractWriteResponseHandler; +import org.apache.cassandra.service.StorageProxy; +import org.apache.cassandra.service.paxos.Commit; +import org.apache.cassandra.utils.FBUtilities; + +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.concurrent.Stage.INTERNAL_RESPONSE; +import static org.apache.cassandra.utils.MonotonicClock.preciseTime; + +/** + * An expiring map of request callbacks. + * + * Used to match response (id, peer) pairs to corresponding {@link RequestCallback}s, or, if said responses + * don't arrive in a timely manner (within verb's timeout), to expire the callbacks. + * + * Since we reuse the same request id for multiple messages now, the map is keyed by (id, peer) tuples + * rather than just id as it used to before 4.0. 
+ */ +public class RequestCallbacks implements OutboundMessageCallbacks +{ + private static final Logger logger = LoggerFactory.getLogger(RequestCallbacks.class); + + private final MessagingService messagingService; + private final ScheduledExecutorService executor = new DebuggableScheduledThreadPoolExecutor("Callback-Map-Reaper"); + private final ConcurrentMap callbacks = new ConcurrentHashMap<>(); + + RequestCallbacks(MessagingService messagingService) + { + this.messagingService = messagingService; + + long expirationInterval = DatabaseDescriptor.getMinRpcTimeout(NANOSECONDS) / 2; + executor.scheduleWithFixedDelay(this::expire, expirationInterval, expirationInterval, NANOSECONDS); + } + + /** + * @return the registered {@link CallbackInfo} for this id and peer, or {@code null} if unset or expired. + */ + @Nullable + CallbackInfo get(long id, InetAddressAndPort peer) + { + return callbacks.get(key(id, peer)); + } + + /** + * Remove and return the {@link CallbackInfo} associated with given id and peer, if known. + */ + @Nullable + CallbackInfo remove(long id, InetAddressAndPort peer) + { + return callbacks.remove(key(id, peer)); + } + + /** + * Register the provided {@link RequestCallback}, inferring expiry and id from the provided {@link Message}. + */ + void addWithExpiration(RequestCallback cb, Message message, InetAddressAndPort to) + { + // mutations need to call the overload with a ConsistencyLevel + assert message.verb() != Verb.MUTATION_REQ && message.verb() != Verb.COUNTER_MUTATION_REQ && message.verb() != Verb.PAXOS_COMMIT_REQ; + CallbackInfo previous = callbacks.put(key(message.id(), to), new CallbackInfo(message, to, cb)); + assert previous == null : format("Callback already exists for id %d/%s! 
(%s)", message.id(), to, previous); + } + + // FIXME: shouldn't need a special overload for writes; hinting should be part of AbstractWriteResponseHandler + public void addWithExpiration(AbstractWriteResponseHandler cb, + Message message, + Replica to, + ConsistencyLevel consistencyLevel, + boolean allowHints) + { + assert message.verb() == Verb.MUTATION_REQ || message.verb() == Verb.COUNTER_MUTATION_REQ || message.verb() == Verb.PAXOS_COMMIT_REQ; + CallbackInfo previous = callbacks.put(key(message.id(), to.endpoint()), new WriteCallbackInfo(message, to, cb, consistencyLevel, allowHints)); + assert previous == null : format("Callback already exists for id %d/%s! (%s)", message.id(), to.endpoint(), previous); + } + + IVersionedAsymmetricSerializer responseSerializer(long id, InetAddressAndPort peer) + { + CallbackInfo info = get(id, peer); + return info == null ? null : info.responseVerb.serializer(); + } + + @VisibleForTesting + public void removeAndRespond(long id, InetAddressAndPort peer, Message message) + { + CallbackInfo ci = remove(id, peer); + if (null != ci) ci.callback.onResponse(message); + } + + private void removeAndExpire(long id, InetAddressAndPort peer) + { + CallbackInfo ci = remove(id, peer); + if (null != ci) onExpired(ci); + } + + private void expire() + { + long start = preciseTime.now(); + int n = 0; + for (Map.Entry entry : callbacks.entrySet()) + { + if (entry.getValue().isReadyToDieAt(start)) + { + if (callbacks.remove(entry.getKey(), entry.getValue())) + { + n++; + onExpired(entry.getValue()); + } + } + } + logger.trace("Expired {} entries", n); + } + + private void forceExpire() + { + for (Map.Entry entry : callbacks.entrySet()) + if (callbacks.remove(entry.getKey(), entry.getValue())) + onExpired(entry.getValue()); + } + + private void onExpired(CallbackInfo info) + { + messagingService.latencySubscribers.maybeAdd(info.callback, info.peer, info.timeout(), NANOSECONDS); + + InternodeOutboundMetrics.totalExpiredCallbacks.mark(); + 
messagingService.markExpiredCallback(info.peer); + + if (info.callback.supportsBackPressure()) + messagingService.updateBackPressureOnReceive(info.peer, info.callback, true); + + if (info.invokeOnFailure()) + StageManager.getStage(INTERNAL_RESPONSE).submit(() -> info.callback.onFailure(info.peer, RequestFailureReason.TIMEOUT)); + + // FIXME: this has never belonged here, should be part of onFailure() in AbstractWriteResponseHandler + if (info.shouldHint()) + { + WriteCallbackInfo writeCallbackInfo = ((WriteCallbackInfo) info); + Mutation mutation = writeCallbackInfo.mutation(); + StorageProxy.submitHint(mutation, writeCallbackInfo.getReplica(), null); + } + } + + void shutdownNow(boolean expireCallbacks) + { + executor.shutdownNow(); + if (expireCallbacks) + forceExpire(); + } + + void shutdownGracefully() + { + expire(); + if (!callbacks.isEmpty()) + executor.schedule(this::shutdownGracefully, 100L, MILLISECONDS); + else + executor.shutdownNow(); + } + + void awaitTerminationUntil(long deadlineNanos) throws TimeoutException, InterruptedException + { + if (!executor.isTerminated()) + { + long wait = deadlineNanos - System.nanoTime(); + if (wait <= 0 || !executor.awaitTermination(wait, NANOSECONDS)) + throw new TimeoutException(); + } + } + + @VisibleForTesting + public void unsafeClear() + { + callbacks.clear(); + } + + private static CallbackKey key(long id, InetAddressAndPort peer) + { + return new CallbackKey(id, peer); + } + + private static class CallbackKey + { + final long id; + final InetAddressAndPort peer; + + CallbackKey(long id, InetAddressAndPort peer) + { + this.id = id; + this.peer = peer; + } + + @Override + public boolean equals(Object o) + { + if (!(o instanceof CallbackKey)) + return false; + CallbackKey that = (CallbackKey) o; + return this.id == that.id && this.peer.equals(that.peer); + } + + @Override + public int hashCode() + { + return Long.hashCode(id) + 31 * peer.hashCode(); + } + + @Override + public String toString() + { + return "{id:" 
+ id + ", peer:" + peer + '}'; + } + } + + static class CallbackInfo + { + final long createdAtNanos; + final long expiresAtNanos; + + final InetAddressAndPort peer; + final RequestCallback callback; + + @Deprecated // for 3.0 compatibility purposes only + public final Verb responseVerb; + + private CallbackInfo(Message message, InetAddressAndPort peer, RequestCallback callback) + { + this.createdAtNanos = message.createdAtNanos(); + this.expiresAtNanos = message.expiresAtNanos(); + this.peer = peer; + this.callback = callback; + this.responseVerb = message.verb().responseVerb; + } + + public long timeout() + { + return expiresAtNanos - createdAtNanos; + } + + boolean isReadyToDieAt(long atNano) + { + return atNano > expiresAtNanos; + } + + boolean shouldHint() + { + return false; + } + + boolean invokeOnFailure() + { + return callback.invokeOnFailure(); + } + + public String toString() + { + return "{peer:" + peer + ", callback:" + callback + ", invokeOnFailure:" + invokeOnFailure() + '}'; + } + } + + // FIXME: shouldn't need a specialized container for write callbacks; hinting should be part of + // AbstractWriteResponseHandler implementation. + static class WriteCallbackInfo extends CallbackInfo + { + // either a Mutation, or a Paxos Commit (MessageOut) + private final Object mutation; + private final Replica replica; + + @VisibleForTesting + WriteCallbackInfo(Message message, Replica replica, RequestCallback callback, ConsistencyLevel consistencyLevel, boolean allowHints) + { + super(message, replica.endpoint(), callback); + this.mutation = shouldHint(allowHints, message, consistencyLevel) ? 
message.payload : null; + //Local writes shouldn't go through messaging service (https://issues.apache.org/jira/browse/CASSANDRA-10477) + //noinspection AssertWithSideEffects + assert !peer.equals(FBUtilities.getBroadcastAddressAndPort()); + this.replica = replica; + } + + public boolean shouldHint() + { + return mutation != null && StorageProxy.shouldHint(replica); + } + + public Replica getReplica() + { + return replica; + } + + public Mutation mutation() + { + return getMutation(mutation); + } + + private static Mutation getMutation(Object object) + { + assert object instanceof Commit || object instanceof Mutation : object; + return object instanceof Commit ? ((Commit) object).makeMutation() + : (Mutation) object; + } + + private static boolean shouldHint(boolean allowHints, Message sentMessage, ConsistencyLevel consistencyLevel) + { + return allowHints && sentMessage.verb() != Verb.COUNTER_MUTATION_REQ && consistencyLevel != ConsistencyLevel.ANY; + } + } + + @Override + public void onOverloaded(Message message, InetAddressAndPort peer) + { + removeAndExpire(message, peer); + } + + @Override + public void onExpired(Message message, InetAddressAndPort peer) + { + removeAndExpire(message, peer); + } + + @Override + public void onFailedSerialize(Message message, InetAddressAndPort peer, int messagingVersion, int bytesWrittenToNetwork, Throwable failure) + { + removeAndExpire(message, peer); + } + + @Override + public void onDiscardOnClose(Message message, InetAddressAndPort peer) + { + removeAndExpire(message, peer); + } + + private void removeAndExpire(Message message, InetAddressAndPort peer) + { + removeAndExpire(message.id(), peer); + + /* in case of a write sent to a different DC, also expire all forwarding targets */ + ForwardingInfo forwardTo = message.forwardTo(); + if (null != forwardTo) + forwardTo.forEach(this::removeAndExpire); + } +} diff --git a/src/java/org/apache/cassandra/net/ResourceLimits.java 
b/src/java/org/apache/cassandra/net/ResourceLimits.java new file mode 100644 index 000000000000..f8d24d778bdc --- /dev/null +++ b/src/java/org/apache/cassandra/net/ResourceLimits.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +public abstract class ResourceLimits +{ + /** + * Represents permits to utilise a resource and ways to allocate and release them. + * + * Two implementations are currently provided: + * 1. {@link Concurrent}, for shared limits, which is thread-safe; + * 2. {@link Basic}, for limits that are not shared between threads, is not thread-safe. + */ + public interface Limit + { + /** + * @return total amount of permits represented by this {@link Limit} - the capacity + */ + long limit(); + + /** + * @return remaining, unallocated permit amount + */ + long remaining(); + + /** + * @return amount of permits currently in use + */ + long using(); + + /** + * Attempts to allocate an amount of permits from this limit. If allocated, MUST eventually + * be released back with {@link #release(long)}. 
+ * + * @return {@code true} if the allocation was successful, {@code false} otherwise + */ + boolean tryAllocate(long amount); + + /** + * Allocates an amount independent of permits available from this limit. MUST eventually + * be released back with {@link #release(long)}. + * + */ + void allocate(long amount); + + /** + * @param amount return the amount of permits back to this limit + * @return {@code ABOVE_LIMIT} if there aren't enough permits available even after the release, or + * {@code BELOW_LIMIT} if there are enough permits available after the release. + */ + Outcome release(long amount); + } + + /** + * A thread-safe permit container. + */ + public static class Concurrent implements Limit + { + private final long limit; + + private volatile long using; + private static final AtomicLongFieldUpdater usingUpdater = + AtomicLongFieldUpdater.newUpdater(Concurrent.class, "using"); + + public Concurrent(long limit) + { + this.limit = limit; + } + + public long limit() + { + return limit; + } + + public long remaining() + { + return limit - using; + } + + public long using() + { + return using; + } + + public boolean tryAllocate(long amount) + { + long current, next; + do + { + current = using; + next = current + amount; + + if (next > limit) + return false; + } + while (!usingUpdater.compareAndSet(this, current, next)); + + return true; + } + + public void allocate(long amount) + { + long current, next; + do + { + current = using; + next = current + amount; + } while (!usingUpdater.compareAndSet(this, current, next)); + } + + public Outcome release(long amount) + { + assert amount >= 0; + long using = usingUpdater.addAndGet(this, -amount); + assert using >= 0; + return using >= limit ? Outcome.ABOVE_LIMIT : Outcome.BELOW_LIMIT; + } + } + + /** + * A cheaper, thread-unsafe permit container to be used for unshared limits.
+ */ + static class Basic implements Limit + { + private final long limit; + private long using; + + Basic(long limit) + { + this.limit = limit; + } + + public long limit() + { + return limit; + } + + public long remaining() + { + return limit - using; + } + + public long using() + { + return using; + } + + public boolean tryAllocate(long amount) + { + if (using + amount > limit) + return false; + + using += amount; + return true; + } + + public void allocate(long amount) + { + using += amount; + } + + public Outcome release(long amount) + { + assert amount >= 0 && amount <= using; + using -= amount; + return using >= limit ? Outcome.ABOVE_LIMIT : Outcome.BELOW_LIMIT; + } + } + + /** + * A convenience class that groups a per-endpoint limit with the global one + * to allow allocating/releasing permits from/to both limits as one logical operation. + */ + public static class EndpointAndGlobal + { + final Limit endpoint; + final Limit global; + + public EndpointAndGlobal(Limit endpoint, Limit global) + { + this.endpoint = endpoint; + this.global = global; + } + + public Limit endpoint() + { + return endpoint; + } + + public Limit global() + { + return global; + } + + /** + * @return {@code INSUFFICIENT_GLOBAL} if there weren't enough permits in the global limit, or + * {@code INSUFFICIENT_ENDPOINT} if there weren't enough permits in the per-endpoint limit, or + * {@code SUCCESS} if there were enough permits to take from both. 
+ */ + public Outcome tryAllocate(long amount) + { + if (!global.tryAllocate(amount)) + return Outcome.INSUFFICIENT_GLOBAL; + + if (endpoint.tryAllocate(amount)) + return Outcome.SUCCESS; + + global.release(amount); + return Outcome.INSUFFICIENT_ENDPOINT; + } + + public void allocate(long amount) + { + global.allocate(amount); + endpoint.allocate(amount); + } + + public Outcome release(long amount) + { + Outcome endpointReleaseOutcome = endpoint.release(amount); + Outcome globalReleaseOutcome = global.release(amount); + return (endpointReleaseOutcome == Outcome.ABOVE_LIMIT || globalReleaseOutcome == Outcome.ABOVE_LIMIT) + ? Outcome.ABOVE_LIMIT : Outcome.BELOW_LIMIT; + } + } + + public enum Outcome { SUCCESS, INSUFFICIENT_ENDPOINT, INSUFFICIENT_GLOBAL, BELOW_LIMIT, ABOVE_LIMIT } +} diff --git a/src/java/org/apache/cassandra/net/ResponseVerbHandler.java b/src/java/org/apache/cassandra/net/ResponseVerbHandler.java index fe22e42e356b..e5779ab91938 100644 --- a/src/java/org/apache/cassandra/net/ResponseVerbHandler.java +++ b/src/java/org/apache/cassandra/net/ResponseVerbHandler.java @@ -17,45 +17,50 @@ */ package org.apache.cassandra.net; -import java.util.concurrent.TimeUnit; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.tracing.Tracing; -public class ResponseVerbHandler implements IVerbHandler +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +class ResponseVerbHandler implements IVerbHandler { - private static final Logger logger = LoggerFactory.getLogger( ResponseVerbHandler.class ); + public static final ResponseVerbHandler instance = new ResponseVerbHandler(); - public void doVerb(MessageIn message, int id) + private static final Logger logger = LoggerFactory.getLogger(ResponseVerbHandler.class); + + @Override + public void doVerb(Message message) { - long latency = 
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - MessagingService.instance().getRegisteredCallbackAge(id)); - CallbackInfo callbackInfo = MessagingService.instance().removeRegisteredCallback(id); + RequestCallbacks.CallbackInfo callbackInfo = MessagingService.instance().callbacks.remove(message.id(), message.from()); if (callbackInfo == null) { String msg = "Callback already removed for {} (from {})"; - logger.trace(msg, id, message.from); - Tracing.trace(msg, id, message.from); + logger.trace(msg, message.id(), message.from()); + Tracing.trace(msg, message.id(), message.from()); return; } - Tracing.trace("Processing response from {}", message.from); - IAsyncCallback cb = callbackInfo.callback; + long latencyNanos = approxTime.now() - callbackInfo.createdAtNanos; + Tracing.trace("Processing response from {}", message.from()); + + RequestCallback cb = callbackInfo.callback; if (message.isFailureResponse()) { - ((IAsyncCallbackWithFailure) cb).onFailure(message.from, message.getFailureReason()); + cb.onFailure(message.from(), (RequestFailureReason) message.payload); } else { - //TODO: Should we add latency only in success cases? - MessagingService.instance().maybeAddLatency(cb, message.from, latency); - cb.response(message); + MessagingService.instance().latencySubscribers.maybeAdd(cb, message.from(), latencyNanos, NANOSECONDS); + cb.onResponse(message); } if (callbackInfo.callback.supportsBackPressure()) { - MessagingService.instance().updateBackPressureOnReceive(message.from, cb, false); + MessagingService.instance().updateBackPressureOnReceive(message.from(), cb, false); } } } diff --git a/src/java/org/apache/cassandra/net/ShareableBytes.java b/src/java/org/apache/cassandra/net/ShareableBytes.java new file mode 100644 index 000000000000..e4f24608e4d2 --- /dev/null +++ b/src/java/org/apache/cassandra/net/ShareableBytes.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +import org.apache.cassandra.utils.memory.BufferPool; + +/** + * A wrapper for possibly sharing portions of a single, {@link BufferPool} managed, {@link ByteBuffer}; + * optimised for the case where no sharing is necessary. + * + * When sharing is necessary, {@link #share()} method must be invoked by the owning thread + * before a {@link ShareableBytes} instance can be shared with another thread. 
+ */ +class ShareableBytes +{ + private final ByteBuffer bytes; + private final ShareableBytes owner; + private volatile int count; + + private static final int UNSHARED = -1; + private static final int RELEASED = 0; + private static final AtomicIntegerFieldUpdater countUpdater = + AtomicIntegerFieldUpdater.newUpdater(ShareableBytes.class, "count"); + + private ShareableBytes(ByteBuffer bytes) + { + this.count = UNSHARED; + this.owner = this; + this.bytes = bytes; + } + + private ShareableBytes(ShareableBytes owner, ByteBuffer bytes) + { + this.owner = owner; + this.bytes = bytes; + } + + ByteBuffer get() + { + assert owner.count != 0; + return bytes; + } + + boolean hasRemaining() + { + return bytes.hasRemaining(); + } + + int remaining() + { + return bytes.remaining(); + } + + void skipBytes(int skipBytes) + { + bytes.position(bytes.position() + skipBytes); + } + + void consume() + { + bytes.position(bytes.limit()); + } + + /** + * Ensure this ShareableBytes will use atomic operations for updating its count from now on. 
+ * The first invocation must occur while the calling thread has exclusive access (though there may be more + * than one 'owner', these must all either be owned by the calling thread or otherwise not being used) + */ + ShareableBytes share() + { + int count = owner.count; + if (count < 0) + owner.count = -count; + return this; + } + + private ShareableBytes retain() + { + owner.doRetain(); + return this; + } + + private void doRetain() + { + int count = this.count; + if (count < 0) + { + countUpdater.lazySet(this, count - 1); + return; + } + + while (true) + { + if (count == RELEASED) + throw new IllegalStateException("Attempted to reference an already released SharedByteBuffer"); + + if (countUpdater.compareAndSet(this, count, count + 1)) + return; + + count = this.count; + } + } + + void release() + { + owner.doRelease(); + } + + private void doRelease() + { + int count = this.count; + + if (count < 0) + countUpdater.lazySet(this, count += 1); + else if (count > 0) + count = countUpdater.decrementAndGet(this); + else + throw new IllegalStateException("Already released"); + + if (count == RELEASED) + BufferPool.put(bytes); + } + + boolean isReleased() + { + return owner.count == RELEASED; + } + + /** + * Create a slice over the next {@code length} bytes, consuming them from our buffer, and incrementing the owner count + */ + ShareableBytes sliceAndConsume(int length) + { + int begin = bytes.position(); + int end = begin + length; + ShareableBytes result = slice(begin, end); + bytes.position(end); + return result; + } + + /** + * Create a new slice, incrementing the number of owners (making it shared if it was previously unshared) + */ + ShareableBytes slice(int begin, int end) + { + ByteBuffer slice = bytes.duplicate(); + slice.position(begin).limit(end); + return new ShareableBytes(owner.retain(), slice); + } + + static ShareableBytes wrap(ByteBuffer buffer) + { + return new ShareableBytes(buffer); + } +} + diff --git 
a/src/java/org/apache/cassandra/net/SharedDefaultFileRegion.java b/src/java/org/apache/cassandra/net/SharedDefaultFileRegion.java new file mode 100644 index 000000000000..6b47c22d240e --- /dev/null +++ b/src/java/org/apache/cassandra/net/SharedDefaultFileRegion.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.nio.channels.FileChannel; +import java.util.concurrent.atomic.AtomicInteger; + +import io.netty.channel.DefaultFileRegion; +import org.apache.cassandra.utils.concurrent.Ref; +import org.apache.cassandra.utils.concurrent.RefCounted; + +/** + * Netty's DefaultFileRegion closes the underlying FileChannel as soon as + * the refCnt() for the region drops to zero, this is an implementation of + * the DefaultFileRegion that doesn't close the FileChannel. + * + * See {@link AsyncChannelOutputPlus} for its usage. 
+ */ +public class SharedDefaultFileRegion extends DefaultFileRegion +{ + public static class SharedFileChannel + { + // we don't call .ref() on this, because it would generate a lot of PhantomReferences and GC overhead, + // but we use it to ensure we can spot memory leaks + final Ref ref; + final AtomicInteger refCount = new AtomicInteger(1); + + SharedFileChannel(FileChannel fileChannel) + { + this.ref = new Ref<>(fileChannel, new RefCounted.Tidy() + { + public void tidy() throws Exception + { + // don't mind invoking this on eventLoop, as only used with sendFile which is also blocking + // so must use streaming eventLoop + fileChannel.close(); + } + + public String name() + { + return "SharedFileChannel[" + fileChannel.toString() + ']'; + } + }); + } + + public void release() + { + if (0 == refCount.decrementAndGet()) + ref.release(); + } + } + + private final SharedFileChannel shared; + private boolean deallocated = false; + + SharedDefaultFileRegion(SharedFileChannel shared, long position, long count) + { + super(shared.ref.get(), position, count); + this.shared = shared; + if (1 >= this.shared.refCount.incrementAndGet()) + throw new IllegalStateException(); + } + + @Override + protected void deallocate() + { + if (deallocated) + return; + deallocated = true; + shared.release(); + } + + public static SharedFileChannel share(FileChannel fileChannel) + { + return new SharedFileChannel(fileChannel); + } +} diff --git a/src/java/org/apache/cassandra/net/SocketFactory.java b/src/java/org/apache/cassandra/net/SocketFactory.java new file mode 100644 index 000000000000..da2d4612a1ee --- /dev/null +++ b/src/java/org/apache/cassandra/net/SocketFactory.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.spi.SelectorProvider; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeoutException; +import javax.annotation.Nullable; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; + +import com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFactory; +import io.netty.channel.DefaultSelectStrategyFactory; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.EpollChannelOption; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.channel.unix.Errors; +import 
io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.DefaultEventExecutorChooserFactory; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.ThreadPerTaskExecutor; +import io.netty.util.internal.logging.InternalLoggerFactory; +import io.netty.util.internal.logging.Slf4JLoggerFactory; +import org.apache.cassandra.concurrent.NamedThreadFactory; +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.security.SSLFactory; +import org.apache.cassandra.service.NativeTransportService; +import org.apache.cassandra.utils.ExecutorUtils; +import org.apache.cassandra.utils.FBUtilities; + +import static io.netty.channel.unix.Errors.ERRNO_ECONNRESET_NEGATIVE; +import static io.netty.channel.unix.Errors.ERROR_ECONNREFUSED_NEGATIVE; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.apache.cassandra.utils.Throwables.isCausedBy; + +/** + * A factory for building Netty {@link Channel}s. Channels here are setup with a pipeline to participate + * in the internode protocol handshake, either the inbound or outbound side as per the method invoked. + */ +public final class SocketFactory +{ + private static final Logger logger = LoggerFactory.getLogger(SocketFactory.class); + + private static final int EVENT_THREADS = Integer.getInteger(Config.PROPERTY_PREFIX + "internode-event-threads", FBUtilities.getAvailableProcessors()); + + /** + * The default task queue used by {@code NioEventLoop} and {@code EpollEventLoop} is {@code MpscUnboundedArrayQueue}, + * provided by JCTools. 
While efficient, it has an undesirable quality for a queue backing an event loop: it is + * not non-blocking, and can cause the event loop to busy-spin while waiting for a partially completed task + * offer, if the producer thread has been suspended mid-offer. + * + * As it happens, however, we have an MPSC queue implementation that is perfectly fit for this purpose - + * {@link ManyToOneConcurrentLinkedQueue}, that is non-blocking, and already used throughout the codebase, + * that we can and do use here as well. + */ + enum Provider + { + NIO + { + @Override + NioEventLoopGroup makeEventLoopGroup(int threadCount, ThreadFactory threadFactory) + { + return new NioEventLoopGroup(threadCount, + new ThreadPerTaskExecutor(threadFactory), + DefaultEventExecutorChooserFactory.INSTANCE, + SelectorProvider.provider(), + DefaultSelectStrategyFactory.INSTANCE, + RejectedExecutionHandlers.reject(), + capacity -> new ManyToOneConcurrentLinkedQueue<>()); + } + + @Override + ChannelFactory clientChannelFactory() + { + return NioSocketChannel::new; + } + + @Override + ChannelFactory serverChannelFactory() + { + return NioServerSocketChannel::new; + } + }, + EPOLL + { + @Override + EpollEventLoopGroup makeEventLoopGroup(int threadCount, ThreadFactory threadFactory) + { + return new EpollEventLoopGroup(threadCount, + new ThreadPerTaskExecutor(threadFactory), + DefaultEventExecutorChooserFactory.INSTANCE, + DefaultSelectStrategyFactory.INSTANCE, + RejectedExecutionHandlers.reject(), + capacity -> new ManyToOneConcurrentLinkedQueue<>()); + } + + @Override + ChannelFactory clientChannelFactory() + { + return EpollSocketChannel::new; + } + + @Override + ChannelFactory serverChannelFactory() + { + return EpollServerSocketChannel::new; + } + }; + + EventLoopGroup makeEventLoopGroup(int threadCount, String threadNamePrefix) + { + logger.debug("using netty {} event loop for pool prefix {}", name(), threadNamePrefix); + return makeEventLoopGroup(threadCount, new 
DefaultThreadFactory(threadNamePrefix, true)); + } + + abstract EventLoopGroup makeEventLoopGroup(int threadCount, ThreadFactory threadFactory); + abstract ChannelFactory clientChannelFactory(); + abstract ChannelFactory serverChannelFactory(); + + static Provider optimalProvider() + { + return NativeTransportService.useEpoll() ? EPOLL : NIO; + } + } + + /** a useful addition for debugging; simply set to true to get more data in your logs */ + static final boolean WIRETRACE = false; + static + { + if (WIRETRACE) + InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); + } + + private final Provider provider; + private final EventLoopGroup acceptGroup; + private final EventLoopGroup defaultGroup; + // we need a separate EventLoopGroup for outbound streaming because sendFile is blocking + private final EventLoopGroup outboundStreamingGroup; + final ExecutorService synchronousWorkExecutor = Executors.newCachedThreadPool(new NamedThreadFactory("Messaging-SynchronousWork")); + + SocketFactory() + { + this(Provider.optimalProvider()); + } + + SocketFactory(Provider provider) + { + this.provider = provider; + this.acceptGroup = provider.makeEventLoopGroup(1, "Messaging-AcceptLoop"); + this.defaultGroup = provider.makeEventLoopGroup(EVENT_THREADS, NamedThreadFactory.globalPrefix() + "Messaging-EventLoop"); + this.outboundStreamingGroup = provider.makeEventLoopGroup(EVENT_THREADS, "Streaming-EventLoop"); + } + + Bootstrap newClientBootstrap(EventLoop eventLoop, int tcpUserTimeoutInMS) + { + if (eventLoop == null) + throw new IllegalArgumentException("must provide eventLoop"); + + Bootstrap bootstrap = new Bootstrap().group(eventLoop).channelFactory(provider.clientChannelFactory()); + + if (provider == Provider.EPOLL) + bootstrap.option(EpollChannelOption.TCP_USER_TIMEOUT, tcpUserTimeoutInMS); + + return bootstrap; + } + + ServerBootstrap newServerBootstrap() + { + return new ServerBootstrap().group(acceptGroup, 
defaultGroup).channelFactory(provider.serverChannelFactory()); + } + + /** + * Creates a new {@link SslHandler} from provided SslContext. + * @param peer enables endpoint verification for remote address when not null + */ + static SslHandler newSslHandler(Channel channel, SslContext sslContext, @Nullable InetSocketAddress peer) + { + if (peer == null) + return sslContext.newHandler(channel.alloc()); + + logger.debug("Creating SSL handler for {}:{}", peer.getHostString(), peer.getPort()); + SslHandler sslHandler = sslContext.newHandler(channel.alloc(), peer.getHostString(), peer.getPort()); + SSLEngine engine = sslHandler.engine(); + SSLParameters sslParameters = engine.getSSLParameters(); + sslParameters.setEndpointIdentificationAlgorithm("HTTPS"); + engine.setSSLParameters(sslParameters); + return sslHandler; + } + + static String encryptionLogStatement(EncryptionOptions options) + { + if (options == null) + return "disabled"; + + String encryptionType = SSLFactory.openSslIsAvailable() ? 
"openssl" : "jdk"; + return "enabled (" + encryptionType + ')'; + } + + EventLoopGroup defaultGroup() + { + return defaultGroup; + } + + public EventLoopGroup outboundStreamingGroup() + { + return outboundStreamingGroup; + } + + public void shutdownNow() + { + acceptGroup.shutdownGracefully(0, 2, SECONDS); + defaultGroup.shutdownGracefully(0, 2, SECONDS); + outboundStreamingGroup.shutdownGracefully(0, 2, SECONDS); + synchronousWorkExecutor.shutdownNow(); + } + + void awaitTerminationUntil(long deadlineNanos) throws InterruptedException, TimeoutException + { + List groups = ImmutableList.of(acceptGroup, defaultGroup, outboundStreamingGroup, synchronousWorkExecutor); + ExecutorUtils.awaitTerminationUntil(deadlineNanos, groups); + } + + static boolean isConnectionReset(Throwable t) + { + if (t instanceof ClosedChannelException) + return true; + if (t instanceof ConnectException) + return true; + if (t instanceof Errors.NativeIoException) + { + int errorCode = ((Errors.NativeIoException) t).expectedErr(); + return errorCode == ERRNO_ECONNRESET_NEGATIVE || errorCode != ERROR_ECONNREFUSED_NEGATIVE; + } + return IOException.class == t.getClass() && ("Broken pipe".equals(t.getMessage()) || "Connection reset by peer".equals(t.getMessage())); + } + + static boolean isCausedByConnectionReset(Throwable t) + { + return isCausedBy(t, SocketFactory::isConnectionReset); + } + + static String channelId(InetAddressAndPort from, InetSocketAddress realFrom, InetAddressAndPort to, InetSocketAddress realTo, ConnectionType type, String id) + { + return addressId(from, realFrom) + "->" + addressId(to, realTo) + '-' + type + '-' + id; + } + + static String addressId(InetAddressAndPort address, InetSocketAddress realAddress) + { + String str = address.toString(); + if (!address.address.equals(realAddress.getAddress()) || address.port != realAddress.getPort()) + str += '(' + InetAddressAndPort.toString(realAddress.getAddress(), realAddress.getPort()) + ')'; + return str; + } + + static 
String channelId(InetAddressAndPort from, InetAddressAndPort to, ConnectionType type, String id) + { + return from + "->" + to + '-' + type + '-' + id; + } +} diff --git a/src/java/org/apache/cassandra/net/StartupClusterConnectivityChecker.java b/src/java/org/apache/cassandra/net/StartupClusterConnectivityChecker.java index 8e3747068c99..b901338006be 100644 --- a/src/java/org/apache/cassandra/net/StartupClusterConnectivityChecker.java +++ b/src/java/org/apache/cassandra/net/StartupClusterConnectivityChecker.java @@ -42,12 +42,11 @@ import org.apache.cassandra.gms.IEndpointStateChangeSubscriber; import org.apache.cassandra.gms.VersionedValue; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; import org.apache.cassandra.utils.FBUtilities; -import static org.apache.cassandra.net.MessagingService.Verb.PING; -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.LARGE_MESSAGE; -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; +import static org.apache.cassandra.net.Verb.PING_REQ; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; public class StartupClusterConnectivityChecker { @@ -149,11 +148,11 @@ public boolean execute(Set peers, Function peers, Function peers, Map dcToRemainingPeers, AckMap acks, Function getDatacenter) { - IAsyncCallback responseHandler = new IAsyncCallback() - { - public boolean isLatencyForSnitch() - { - return false; - } - - public void response(MessageIn msg) + RequestCallback responseHandler = msg -> { + if (acks.incrementAndCheck(msg.from())) { - if (acks.incrementAndCheck(msg.from)) - { - String datacenter = getDatacenter.apply(msg.from); - // We have to check because we might only have the local DC in the map - if (dcToRemainingPeers.containsKey(datacenter)) - 
dcToRemainingPeers.get(datacenter).countDown(); - } + String datacenter = getDatacenter.apply(msg.from()); + // We have to check because we might only have the local DC in the map + if (dcToRemainingPeers.containsKey(datacenter)) + dcToRemainingPeers.get(datacenter).countDown(); } }; - MessageOut smallChannelMessageOut = new MessageOut<>(PING, PingMessage.smallChannelMessage, - PingMessage.serializer, SMALL_MESSAGE); - MessageOut largeChannelMessageOut = new MessageOut<>(PING, PingMessage.largeChannelMessage, - PingMessage.serializer, LARGE_MESSAGE); + Message small = Message.out(PING_REQ, PingRequest.forSmall); + Message large = Message.out(PING_REQ, PingRequest.forLarge); for (InetAddressAndPort peer : peers) { - MessagingService.instance().sendRR(smallChannelMessageOut, peer, responseHandler); - MessagingService.instance().sendRR(largeChannelMessageOut, peer, responseHandler); + MessagingService.instance().sendWithCallback(small, peer, responseHandler, SMALL_MESSAGES); + MessagingService.instance().sendWithCallback(large, peer, responseHandler, LARGE_MESSAGES); } } diff --git a/src/java/org/apache/cassandra/net/Verb.java b/src/java/org/apache/cassandra/net/Verb.java new file mode 100644 index 000000000000..67d847e939eb --- /dev/null +++ b/src/java/org/apache/cassandra/net/Verb.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.function.ToLongFunction; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import org.apache.cassandra.batchlog.Batch; +import org.apache.cassandra.batchlog.BatchRemoveVerbHandler; +import org.apache.cassandra.batchlog.BatchStoreVerbHandler; +import org.apache.cassandra.concurrent.Stage; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.CounterMutation; +import org.apache.cassandra.db.CounterMutationVerbHandler; +import org.apache.cassandra.db.Mutation; +import org.apache.cassandra.db.MutationVerbHandler; +import org.apache.cassandra.db.ReadCommand; +import org.apache.cassandra.db.ReadCommandVerbHandler; +import org.apache.cassandra.db.ReadRepairVerbHandler; +import org.apache.cassandra.db.ReadResponse; +import org.apache.cassandra.db.SnapshotCommand; +import org.apache.cassandra.db.TruncateResponse; +import org.apache.cassandra.db.TruncateVerbHandler; +import org.apache.cassandra.db.TruncateRequest; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.gms.GossipDigestAck; +import org.apache.cassandra.gms.GossipDigestAck2; +import org.apache.cassandra.gms.GossipDigestAck2VerbHandler; +import org.apache.cassandra.gms.GossipDigestAckVerbHandler; +import org.apache.cassandra.gms.GossipDigestSyn; 
+import org.apache.cassandra.gms.GossipDigestSynVerbHandler; +import org.apache.cassandra.gms.GossipShutdownVerbHandler; +import org.apache.cassandra.hints.HintMessage; +import org.apache.cassandra.hints.HintVerbHandler; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; +import org.apache.cassandra.repair.RepairMessageVerbHandler; +import org.apache.cassandra.repair.messages.AsymmetricSyncRequest; +import org.apache.cassandra.repair.messages.CleanupMessage; +import org.apache.cassandra.repair.messages.FailSession; +import org.apache.cassandra.repair.messages.FinalizeCommit; +import org.apache.cassandra.repair.messages.FinalizePromise; +import org.apache.cassandra.repair.messages.FinalizePropose; +import org.apache.cassandra.repair.messages.PrepareConsistentRequest; +import org.apache.cassandra.repair.messages.PrepareConsistentResponse; +import org.apache.cassandra.repair.messages.PrepareMessage; +import org.apache.cassandra.repair.messages.SnapshotMessage; +import org.apache.cassandra.repair.messages.StatusRequest; +import org.apache.cassandra.repair.messages.StatusResponse; +import org.apache.cassandra.repair.messages.SyncResponse; +import org.apache.cassandra.repair.messages.SyncRequest; +import org.apache.cassandra.repair.messages.ValidationResponse; +import org.apache.cassandra.repair.messages.ValidationRequest; +import org.apache.cassandra.schema.SchemaPullVerbHandler; +import org.apache.cassandra.schema.SchemaPushVerbHandler; +import org.apache.cassandra.schema.SchemaVersionVerbHandler; +import org.apache.cassandra.utils.BooleanSerializer; +import org.apache.cassandra.service.EchoVerbHandler; +import org.apache.cassandra.service.SnapshotVerbHandler; +import org.apache.cassandra.service.paxos.Commit; +import org.apache.cassandra.service.paxos.CommitVerbHandler; +import org.apache.cassandra.service.paxos.PrepareResponse; +import org.apache.cassandra.service.paxos.PrepareVerbHandler; +import org.apache.cassandra.service.paxos.ProposeVerbHandler; 
+import org.apache.cassandra.streaming.ReplicationDoneVerbHandler; +import org.apache.cassandra.utils.UUIDSerializer; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.concurrent.Stage.*; +import static org.apache.cassandra.concurrent.Stage.INTERNAL_RESPONSE; +import static org.apache.cassandra.concurrent.Stage.MISC; +import static org.apache.cassandra.net.VerbTimeouts.*; +import static org.apache.cassandra.net.Verb.Priority.*; +import static org.apache.cassandra.schema.MigrationManager.MigrationsSerializer; + +/** + * Note that priorities except P0 are presently unused. P0 corresponds to urgent, i.e. what used to be the "Gossip" connection. + */ +public enum Verb +{ + MUTATION_RSP (60, P1, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + MUTATION_REQ (0, P3, writeTimeout, MUTATION, () -> Mutation.serializer, () -> MutationVerbHandler.instance, MUTATION_RSP ), + HINT_RSP (61, P1, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + HINT_REQ (1, P4, writeTimeout, MUTATION, () -> HintMessage.serializer, () -> HintVerbHandler.instance, HINT_RSP ), + READ_REPAIR_RSP (62, P1, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + READ_REPAIR_REQ (2, P1, writeTimeout, MUTATION, () -> Mutation.serializer, () -> ReadRepairVerbHandler.instance, READ_REPAIR_RSP ), + BATCH_STORE_RSP (65, P1, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + BATCH_STORE_REQ (5, P3, writeTimeout, MUTATION, () -> Batch.serializer, () -> BatchStoreVerbHandler.instance, BATCH_STORE_RSP ), + BATCH_REMOVE_RSP (66, P1, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + BATCH_REMOVE_REQ (6, P3, writeTimeout, MUTATION, () -> UUIDSerializer.serializer, () -> BatchRemoveVerbHandler.instance, BATCH_REMOVE_RSP ), + + 
PAXOS_PREPARE_RSP (93, P2, writeTimeout, REQUEST_RESPONSE, () -> PrepareResponse.serializer, () -> ResponseVerbHandler.instance ), + PAXOS_PREPARE_REQ (33, P2, writeTimeout, MUTATION, () -> Commit.serializer, () -> PrepareVerbHandler.instance, PAXOS_PREPARE_RSP ), + PAXOS_PROPOSE_RSP (94, P2, writeTimeout, REQUEST_RESPONSE, () -> BooleanSerializer.serializer, () -> ResponseVerbHandler.instance ), + PAXOS_PROPOSE_REQ (34, P2, writeTimeout, MUTATION, () -> Commit.serializer, () -> ProposeVerbHandler.instance, PAXOS_PROPOSE_RSP ), + PAXOS_COMMIT_RSP (95, P2, writeTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + PAXOS_COMMIT_REQ (35, P2, writeTimeout, MUTATION, () -> Commit.serializer, () -> CommitVerbHandler.instance, PAXOS_COMMIT_RSP ), + + TRUNCATE_RSP (79, P0, truncateTimeout, REQUEST_RESPONSE, () -> TruncateResponse.serializer, () -> ResponseVerbHandler.instance ), + TRUNCATE_REQ (19, P0, truncateTimeout, MUTATION, () -> TruncateRequest.serializer, () -> TruncateVerbHandler.instance, TRUNCATE_RSP ), + + COUNTER_MUTATION_RSP (84, P1, counterTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + COUNTER_MUTATION_REQ (24, P2, counterTimeout, COUNTER_MUTATION, () -> CounterMutation.serializer, () -> CounterMutationVerbHandler.instance, COUNTER_MUTATION_RSP), + + READ_RSP (63, P2, readTimeout, REQUEST_RESPONSE, () -> ReadResponse.serializer, () -> ResponseVerbHandler.instance ), + READ_REQ (3, P3, readTimeout, READ, () -> ReadCommand.serializer, () -> ReadCommandVerbHandler.instance, READ_RSP ), + RANGE_RSP (69, P2, rangeTimeout, REQUEST_RESPONSE, () -> ReadResponse.serializer, () -> ResponseVerbHandler.instance ), + RANGE_REQ (9, P3, rangeTimeout, READ, () -> ReadCommand.serializer, () -> ReadCommandVerbHandler.instance, RANGE_RSP ), + + GOSSIP_DIGEST_SYN (14, P0, longTimeout, GOSSIP, () -> GossipDigestSyn.serializer, () -> GossipDigestSynVerbHandler.instance ), + GOSSIP_DIGEST_ACK 
(15, P0, longTimeout, GOSSIP, () -> GossipDigestAck.serializer, () -> GossipDigestAckVerbHandler.instance ), + GOSSIP_DIGEST_ACK2 (16, P0, longTimeout, GOSSIP, () -> GossipDigestAck2.serializer, () -> GossipDigestAck2VerbHandler.instance ), + GOSSIP_SHUTDOWN (29, P0, rpcTimeout, GOSSIP, () -> NoPayload.serializer, () -> GossipShutdownVerbHandler.instance ), + + ECHO_RSP (91, P0, rpcTimeout, GOSSIP, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + ECHO_REQ (31, P0, rpcTimeout, GOSSIP, () -> NoPayload.serializer, () -> EchoVerbHandler.instance, ECHO_RSP ), + PING_RSP (97, P1, pingTimeout, GOSSIP, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + PING_REQ (37, P1, pingTimeout, GOSSIP, () -> PingRequest.serializer, () -> PingVerbHandler.instance, PING_RSP ), + + // P1 because messages can be arbitrarily large or aren't crucial + SCHEMA_PUSH_RSP (98, P1, rpcTimeout, MIGRATION, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + SCHEMA_PUSH_REQ (18, P1, rpcTimeout, MIGRATION, () -> MigrationsSerializer.instance, () -> SchemaPushVerbHandler.instance, SCHEMA_PUSH_RSP ), + SCHEMA_PULL_RSP (88, P1, rpcTimeout, MIGRATION, () -> MigrationsSerializer.instance, () -> ResponseVerbHandler.instance ), + SCHEMA_PULL_REQ (28, P1, rpcTimeout, MIGRATION, () -> NoPayload.serializer, () -> SchemaPullVerbHandler.instance, SCHEMA_PULL_RSP ), + SCHEMA_VERSION_RSP (80, P1, rpcTimeout, MIGRATION, () -> UUIDSerializer.serializer, () -> ResponseVerbHandler.instance ), + SCHEMA_VERSION_REQ (20, P1, rpcTimeout, MIGRATION, () -> NoPayload.serializer, () -> SchemaVersionVerbHandler.instance, SCHEMA_VERSION_RSP ), + + // repair; mostly doesn't use callbacks and sends responses as their own request messages, with matching sessions by uuid; should eventually harmonize and make idiomatic + REPAIR_RSP (100, P1, rpcTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + VALIDATION_RSP (102, P1, rpcTimeout, 
ANTI_ENTROPY, () -> ValidationResponse.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + VALIDATION_REQ (101, P1, rpcTimeout, ANTI_ENTROPY, () -> ValidationRequest.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + SYNC_RSP (104, P1, rpcTimeout, ANTI_ENTROPY, () -> SyncResponse.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + SYNC_REQ (103, P1, rpcTimeout, ANTI_ENTROPY, () -> SyncRequest.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + PREPARE_MSG (105, P1, rpcTimeout, ANTI_ENTROPY, () -> PrepareMessage.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + SNAPSHOT_MSG (106, P1, rpcTimeout, ANTI_ENTROPY, () -> SnapshotMessage.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + CLEANUP_MSG (107, P1, rpcTimeout, ANTI_ENTROPY, () -> CleanupMessage.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + PREPARE_CONSISTENT_RSP (109, P1, rpcTimeout, ANTI_ENTROPY, () -> PrepareConsistentResponse.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + PREPARE_CONSISTENT_REQ (108, P1, rpcTimeout, ANTI_ENTROPY, () -> PrepareConsistentRequest.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + FINALIZE_PROPOSE_MSG (110, P1, rpcTimeout, ANTI_ENTROPY, () -> FinalizePropose.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + FINALIZE_PROMISE_MSG (111, P1, rpcTimeout, ANTI_ENTROPY, () -> FinalizePromise.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + FINALIZE_COMMIT_MSG (112, P1, rpcTimeout, ANTI_ENTROPY, () -> FinalizeCommit.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + FAILED_SESSION_MSG (113, P1, rpcTimeout, ANTI_ENTROPY, () -> FailSession.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + STATUS_RSP (115, P1, rpcTimeout, ANTI_ENTROPY, () -> StatusResponse.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + STATUS_REQ (114, P1, 
rpcTimeout, ANTI_ENTROPY, () -> StatusRequest.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + ASYMMETRIC_SYNC_REQ (116, P1, rpcTimeout, ANTI_ENTROPY, () -> AsymmetricSyncRequest.serializer, () -> RepairMessageVerbHandler.instance, REPAIR_RSP ), + + REPLICATION_DONE_RSP (82, P0, rpcTimeout, MISC, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + REPLICATION_DONE_REQ (22, P0, rpcTimeout, MISC, () -> NoPayload.serializer, () -> ReplicationDoneVerbHandler.instance, REPLICATION_DONE_RSP), + SNAPSHOT_RSP (87, P0, rpcTimeout, MISC, () -> NoPayload.serializer, () -> ResponseVerbHandler.instance ), + SNAPSHOT_REQ (27, P0, rpcTimeout, MISC, () -> SnapshotCommand.serializer, () -> SnapshotVerbHandler.instance, SNAPSHOT_RSP ), + + // generic failure response + FAILURE_RSP (99, P0, noTimeout, REQUEST_RESPONSE, () -> RequestFailureReason.serializer, () -> ResponseVerbHandler.instance ), + + // dummy verbs + _TRACE (30, P1, rpcTimeout, TRACING, () -> NoPayload.serializer, () -> null ), + _SAMPLE (42, P1, rpcTimeout, INTERNAL_RESPONSE, () -> NoPayload.serializer, () -> null ), + _TEST_1 (10, P0, writeTimeout, IMMEDIATE, () -> NoPayload.serializer, () -> null ), + _TEST_2 (11, P1, rpcTimeout, IMMEDIATE, () -> NoPayload.serializer, () -> null ), + + @Deprecated + REQUEST_RSP (4, P1, rpcTimeout, REQUEST_RESPONSE, () -> null, () -> ResponseVerbHandler.instance ), + @Deprecated + INTERNAL_RSP (23, P1, rpcTimeout, INTERNAL_RESPONSE, () -> null, () -> ResponseVerbHandler.instance ), + + // largest used ID: 116 + ; + + public static final List VERBS = ImmutableList.copyOf(Verb.values()); + + public enum Priority + { + P0, // sends on the urgent connection (i.e. for Gossip, Echo) + P1, // small or empty responses + P2, // larger messages that can be dropped but who have a larger impact on system stability (e.g. 
READ_REPAIR, READ_RSP) + P3, + P4 + } + + public final int id; + public final Priority priority; + public final Stage stage; + + /** + * Messages we receive from peers have a Verb that tells us what kind of message it is. + * Most of the time, this is enough to determine how to deserialize the message payload. + * The exception is the REQUEST_RSP verb, which just means "a response to something you told me to do." + * Traditionally, this was fine since each VerbHandler knew what type of payload it expected, and + * handled the deserialization itself. Now that we do that in ITC, to avoid the extra copy to an + * intermediary byte[] (See CASSANDRA-3716), we need to wire that up to the CallbackInfo object + * (see below). + * + * NOTE: we use a Supplier to avoid loading the dependent classes until necessary. + */ + private final Supplier> serializer; + private final Supplier> handler; + + final Verb responseVerb; + + private final ToLongFunction expiration; + + + /** + * Verbs it's okay to drop if the request has been queued longer than the request timeout. These + * all correspond to client requests or something triggered by them; we don't want to + * drop internal messages like bootstrap or repair notifications. 
+ */ + Verb(int id, Priority priority, ToLongFunction expiration, Stage stage, Supplier> serializer, Supplier> handler) + { + this(id, priority, expiration, stage, serializer, handler, null); + } + + Verb(int id, Priority priority, ToLongFunction expiration, Stage stage, Supplier> serializer, Supplier> handler, Verb responseVerb) + { + this.stage = stage; + if (id < 0) + throw new IllegalArgumentException("Verb id must be non-negative, got " + id + " for verb " + name()); + + this.id = id; + this.priority = priority; + this.serializer = serializer; + this.handler = handler; + this.responseVerb = responseVerb; + this.expiration = expiration; + } + + public IVersionedAsymmetricSerializer serializer() + { + return (IVersionedAsymmetricSerializer) serializer.get(); + } + + public IVerbHandler handler() + { + return (IVerbHandler) handler.get(); + } + + public long expiresAtNanos(long nowNanos) + { + return nowNanos + expiresAfterNanos(); + } + + public long expiresAfterNanos() + { + return expiration.applyAsLong(NANOSECONDS); + } + + // this is a little hacky, but reduces the number of parameters up top + public boolean isResponse() + { + return handler.get() == ResponseVerbHandler.instance; + } + + Verb toPre40Verb() + { + if (!isResponse()) + return this; + if (priority == P0) + return INTERNAL_RSP; + return REQUEST_RSP; + } + + @VisibleForTesting + Supplier> unsafeSetHandler(Supplier> handler) throws NoSuchFieldException, IllegalAccessException + { + Supplier> original = this.handler; + Field field = Verb.class.getDeclaredField("handler"); + field.setAccessible(true); + Field modifiers = Field.class.getDeclaredField("modifiers"); + modifiers.setAccessible(true); + modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); + field.set(this, handler); + return original; + } + + @VisibleForTesting + Supplier> unsafeSetSerializer(Supplier> serializer) throws NoSuchFieldException, IllegalAccessException + { + Supplier> original = this.serializer; + Field field = 
Verb.class.getDeclaredField("serializer"); + field.setAccessible(true); + Field modifiers = Field.class.getDeclaredField("modifiers"); + modifiers.setAccessible(true); + modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); + field.set(this, serializer); + return original; + } + + @VisibleForTesting + ToLongFunction unsafeSetExpiration(ToLongFunction expiration) throws NoSuchFieldException, IllegalAccessException + { + ToLongFunction original = this.expiration; + Field field = Verb.class.getDeclaredField("expiration"); + field.setAccessible(true); + Field modifiers = Field.class.getDeclaredField("modifiers"); + modifiers.setAccessible(true); + modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); + field.set(this, expiration); + return original; + } + + private static final Verb[] idToVerbMap; + + static + { + Verb[] verbs = values(); + int max = -1; + for (Verb v : verbs) + max = Math.max(v.id, max); + + Verb[] idMap = new Verb[max + 1]; + for (Verb v : verbs) + { + if (idMap[v.id] != null) + throw new IllegalArgumentException("cannot have two verbs that map to the same id: " + v + " and " + idMap[v.id]); + idMap[v.id] = v; + } + + idToVerbMap = idMap; + } + + static Verb fromId(int id) + { + Verb verb = id >= 0 && id < idToVerbMap.length ? 
idToVerbMap[id] : null; + if (verb == null) + throw new IllegalArgumentException("Unknown verb id " + id); + return verb; + } +} + +@SuppressWarnings("unused") +class VerbTimeouts +{ + static final ToLongFunction rpcTimeout = DatabaseDescriptor::getRpcTimeout; + static final ToLongFunction writeTimeout = DatabaseDescriptor::getWriteRpcTimeout; + static final ToLongFunction readTimeout = DatabaseDescriptor::getReadRpcTimeout; + static final ToLongFunction rangeTimeout = DatabaseDescriptor::getRangeRpcTimeout; + static final ToLongFunction counterTimeout = DatabaseDescriptor::getCounterWriteRpcTimeout; + static final ToLongFunction truncateTimeout = DatabaseDescriptor::getTruncateRpcTimeout; + static final ToLongFunction pingTimeout = DatabaseDescriptor::getPingTimeout; + static final ToLongFunction longTimeout = units -> Math.max(DatabaseDescriptor.getRpcTimeout(units), units.convert(5L, TimeUnit.MINUTES)); + static final ToLongFunction noTimeout = units -> { throw new IllegalStateException(); }; +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/WriteCallbackInfo.java b/src/java/org/apache/cassandra/net/WriteCallbackInfo.java deleted file mode 100644 index c54e7dcba96f..000000000000 --- a/src/java/org/apache/cassandra/net/WriteCallbackInfo.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.cassandra.net; - -import org.apache.cassandra.db.ConsistencyLevel; -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.service.StorageProxy; -import org.apache.cassandra.service.paxos.Commit; -import org.apache.cassandra.utils.FBUtilities; - -public class WriteCallbackInfo extends CallbackInfo -{ - // either a Mutation, or a Paxos Commit (MessageOut) - private final Object mutation; - private final Replica replica; - - public WriteCallbackInfo(Replica replica, - IAsyncCallback callback, - MessageOut message, - IVersionedSerializer serializer, - ConsistencyLevel consistencyLevel, - boolean allowHints) - { - super(replica.endpoint(), callback, serializer, true); - assert message != null; - this.mutation = shouldHint(allowHints, message, consistencyLevel); - //Local writes shouldn't go through messaging service (https://issues.apache.org/jira/browse/CASSANDRA-10477) - assert (!target.equals(FBUtilities.getBroadcastAddressAndPort())); - this.replica = replica; - } - - public boolean shouldHint() - { - return mutation != null && StorageProxy.shouldHint(replica); - } - - public Replica getReplica() - { - return replica; - } - - public Mutation mutation() - { - return getMutation(mutation); - } - - private static Mutation getMutation(Object object) - { - assert object instanceof Commit || object instanceof Mutation : object; - return object instanceof Commit ? 
((Commit) object).makeMutation() - : (Mutation) object; - } - - private static Object shouldHint(boolean allowHints, MessageOut sentMessage, ConsistencyLevel consistencyLevel) - { - return allowHints - && sentMessage.verb != MessagingService.Verb.COUNTER_MUTATION - && consistencyLevel != ConsistencyLevel.ANY - ? sentMessage.payload : null; - } - -} diff --git a/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java b/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java deleted file mode 100644 index 2f2a9739acc5..000000000000 --- a/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.EOFException; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.ByteToMessageDecoder; -import org.apache.cassandra.db.monitoring.ApproximateTime; -import org.apache.cassandra.exceptions.UnknownTableException; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; - -/** - * Parses out individual messages from the incoming buffers. Each message, both header and payload, is incrementally built up - * from the available input data, then passed to the {@link #messageConsumer}. - * - * Note: this class derives from {@link ByteToMessageDecoder} to take advantage of the {@link ByteToMessageDecoder.Cumulator} - * behavior across {@link #decode(ChannelHandlerContext, ByteBuf, List)} invocations. That way we don't have to maintain - * the not-fully consumed {@link ByteBuf}s. - */ -public abstract class BaseMessageInHandler extends ByteToMessageDecoder -{ - public static final Logger logger = LoggerFactory.getLogger(BaseMessageInHandler.class); - - enum State - { - READ_FIRST_CHUNK, - READ_IP_ADDRESS, - READ_VERB, - READ_PARAMETERS_SIZE, - READ_PARAMETERS_DATA, - READ_PAYLOAD_SIZE, - READ_PAYLOAD, - CLOSED - } - - /** - * The byte count for magic, msg id, timestamp values. - */ - @VisibleForTesting - static final int FIRST_SECTION_BYTE_COUNT = 12; - - static final int VERB_LENGTH = Integer.BYTES; - - /** - * The default target for consuming deserialized {@link MessageIn}. 
- */ - static final BiConsumer MESSAGING_SERVICE_CONSUMER = (messageIn, id) -> MessagingService.instance().receive(messageIn, id); - - /** - * Abstracts out depending directly on {@link MessagingService#receive(MessageIn, int)}; this makes tests more sane - * as they don't require nor trigger the entire message processing circus. - */ - final BiConsumer messageConsumer; - - final InetAddressAndPort peer; - final int messagingVersion; - - protected State state; - - public BaseMessageInHandler(InetAddressAndPort peer, int messagingVersion, BiConsumer messageConsumer) - { - this.peer = peer; - this.messagingVersion = messagingVersion; - this.messageConsumer = messageConsumer; - } - - // redeclared here to make the method public (for testing) - @VisibleForTesting - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception - { - if (state == State.CLOSED) - { - in.skipBytes(in.readableBytes()); - return; - } - - try - { - handleDecode(ctx, in, out); - } - catch (Exception e) - { - // prevent any future attempts at reading messages from any inbound buffers, as we're already in a bad state - state = State.CLOSED; - - // force the buffer to appear to be consumed, thereby exiting the ByteToMessageDecoder.callDecode() loop, - // and other paths in that class, more efficiently - in.skipBytes(in.readableBytes()); - - // throwing the exception up causes the ByteToMessageDecoder.callDecode() loop to exit. 
if we don't do that, - // we'll keep trying to process data out of the last received buffer (and it'll be really, really wrong) - throw e; - } - } - - public abstract void handleDecode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception; - - MessageHeader readFirstChunk(ByteBuf in) throws IOException - { - if (in.readableBytes() < FIRST_SECTION_BYTE_COUNT) - return null; - MessagingService.validateMagic(in.readInt()); - MessageHeader messageHeader = new MessageInHandler.MessageHeader(); - messageHeader.messageId = in.readInt(); - int messageTimestamp = in.readInt(); // make sure to read the sent timestamp, even if DatabaseDescriptor.hasCrossNodeTimeout() is not enabled - messageHeader.constructionTime = MessageIn.deriveConstructionTime(peer, messageTimestamp, ApproximateTime.currentTimeMillis()); - - return messageHeader; - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - if (cause instanceof EOFException) - logger.trace("eof reading from socket; closing", cause); - else if (cause instanceof UnknownTableException) - logger.warn(" Got message from unknown table while reading from socket {}[{}]; closing", - ctx.channel().remoteAddress(), ctx.channel().id(), cause); - else if (cause instanceof IOException) - logger.trace("IOException reading from socket; closing", cause); - else - logger.warn("Unexpected exception caught in inbound channel pipeline from {}[{}]", - ctx.channel().remoteAddress(), ctx.channel().id(), cause); - - ctx.close(); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) - { - logger.trace("received channel closed message for peer {} on local addr {}", ctx.channel().remoteAddress(), ctx.channel().localAddress()); - state = State.CLOSED; - ctx.fireChannelInactive(); - } - - // should ony be used for testing!!! - @VisibleForTesting - abstract MessageHeader getMessageHeader(); - - /** - * A simple struct to hold the message header data as it is being built up. 
- */ - static class MessageHeader - { - int messageId; - long constructionTime; - InetAddressAndPort from; - MessagingService.Verb verb; - int payloadSize; - - Map parameters = Collections.emptyMap(); - - /** - * Length of the parameter data. If the message's version is {@link MessagingService#VERSION_40} or higher, - * this value is the total number of header bytes; else, for legacy messaging, this is the number of - * key/value entries in the header. - */ - int parameterLength; - } - - // for testing purposes only!!! - @VisibleForTesting - public State getState() - { - return state; - } -} diff --git a/src/java/org/apache/cassandra/net/async/ByteBufDataOutputPlus.java b/src/java/org/apache/cassandra/net/async/ByteBufDataOutputPlus.java deleted file mode 100644 index a77cb0713e84..000000000000 --- a/src/java/org/apache/cassandra/net/async/ByteBufDataOutputPlus.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufOutputStream; -import org.apache.cassandra.io.util.CheckedFunction; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.io.util.Memory; -import org.apache.cassandra.io.util.UnbufferedDataOutputStreamPlus; -import org.apache.cassandra.utils.memory.MemoryUtil; -import org.apache.cassandra.utils.vint.VIntCoding; - -/** - * A {@link DataOutputPlus} that uses a {@link ByteBuf} as a backing buffer. This class is completely thread unsafe and - * it is expected that the backing buffer is sized correctly for all the writes you want to do (or the buffer needs - * to be growable). - */ -public class ByteBufDataOutputPlus extends ByteBufOutputStream implements DataOutputPlus -{ - private final ByteBuf buffer; - - /** - * ByteBuffer to use for defensive copies of direct {@link ByteBuffer}s - see {@link #write(ByteBuffer)}. - */ - private final ByteBuffer hollowBuffer = MemoryUtil.getHollowDirectByteBuffer(); - - public ByteBufDataOutputPlus(ByteBuf buffer) - { - super(buffer); - this.buffer = buffer; - } - - /** - * {@inheritDoc} - "write the buffer without modifying its position" - * - * Unfortunately, netty's {@link ByteBuf#writeBytes(ByteBuffer)} modifies the byteBuffer's position, - * and that is unsafe in our world wrt multithreading. Hence we need to be careful: reference the backing array - * on heap ByteBuffers, and use a reusable "hollow" ByteBuffer ({@link #hollowBuffer}) for direct ByteBuffers. 
- */ - @Override - public void write(ByteBuffer byteBuffer) throws IOException - { - if (byteBuffer.hasArray()) - { - write(byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(), byteBuffer.remaining()); - } - else - { - assert byteBuffer.isDirect(); - MemoryUtil.duplicateDirectByteBuffer(byteBuffer, hollowBuffer); - buffer.writeBytes(hollowBuffer); - } - } - - @Override - public void write(Memory memory, long offset, long length) throws IOException - { - for (ByteBuffer buffer : memory.asByteBuffers(offset, length)) - write(buffer); - } - - @Override - public R applyToChannel(CheckedFunction c) throws IOException - { - throw new UnsupportedOperationException(); - } - - @Override - public void writeVInt(long v) throws IOException - { - writeUnsignedVInt(VIntCoding.encodeZigZag64(v)); - } - - @Override - public void writeUnsignedVInt(long v) throws IOException - { - int size = VIntCoding.computeUnsignedVIntSize(v); - if (size == 1) - { - buffer.writeByte((byte) (v & 0xFF)); - return; - } - - buffer.writeBytes(VIntCoding.encodeVInt(v, size), 0, size); - } - - @Override - public void write(int b) throws IOException - { - buffer.writeByte((byte) (b & 0xFF)); - } - - @Override - public void writeByte(int v) throws IOException - { - buffer.writeByte((byte) (v & 0xFF)); - } - - @Override - public void writeBytes(String s) throws IOException - { - for (int index = 0; index < s.length(); index++) - buffer.writeByte(s.charAt(index) & 0xFF); - } - - @Override - public void writeChars(String s) throws IOException - { - for (int index = 0; index < s.length(); index++) - buffer.writeChar(s.charAt(index)); - } - - @Override - public void writeUTF(String s) throws IOException - { - UnbufferedDataOutputStreamPlus.writeUTF(s, this); - } -} diff --git a/src/java/org/apache/cassandra/net/async/ByteBufDataOutputStreamPlus.java b/src/java/org/apache/cassandra/net/async/ByteBufDataOutputStreamPlus.java deleted file mode 100644 index 777bc3e73522..000000000000 --- 
a/src/java/org/apache/cassandra/net/async/ByteBufDataOutputStreamPlus.java +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -import com.google.common.util.concurrent.Uninterruptibles; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.util.concurrent.Future; -import org.apache.cassandra.io.util.BufferedDataOutputStreamPlus; -import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.io.util.FileUtils; -import org.apache.cassandra.streaming.StreamManager.StreamRateLimiter; -import org.apache.cassandra.streaming.StreamSession; - -/** - * A {@link DataOutputStreamPlus} that writes to a {@link ByteBuf}. 
The novelty here is that all writes - * actually get written in to a {@link ByteBuffer} that shares a backing buffer with a {@link ByteBuf}. - * The trick to do that is allocate the ByteBuf, get a ByteBuffer from it by calling {@link ByteBuf#nioBuffer()}, - * and passing that to the super class as {@link #buffer}. When the {@link #buffer} is full or {@link #doFlush(int)} - * is invoked, the {@link #currentBuf} is published to the netty channel. - */ -public class ByteBufDataOutputStreamPlus extends BufferedDataOutputStreamPlus -{ - private final StreamSession session; - private final Channel channel; - private final int bufferSize; - private final Logger logger = LoggerFactory.getLogger(ByteBufDataOutputStreamPlus.class); - - /** - * Tracks how many bytes we've written to the netty channel. This more or less follows the channel's - * high/low water marks and ultimately the 'writablility' status of the channel. Unfortunately there's - * no notification mechanism that can poke a producer to let it know when the channel becomes writable - * (after it was unwritable); hence, the use of a {@link Semaphore}. 
- */ - private final Semaphore channelRateLimiter; - - /** - * This *must* be the owning {@link ByteBuf} for the {@link BufferedDataOutputStreamPlus#buffer} - */ - private ByteBuf currentBuf; - - private ByteBufDataOutputStreamPlus(StreamSession session, Channel channel, ByteBuf buffer, int bufferSize) - { - super(buffer.nioBuffer(0, bufferSize)); - this.session = session; - this.channel = channel; - this.currentBuf = buffer; - this.bufferSize = bufferSize; - channelRateLimiter = new Semaphore(channel.config().getWriteBufferHighWaterMark(), true); - } - - @Override - protected WritableByteChannel newDefaultChannel() - { - return new WritableByteChannel() - { - @Override - public int write(ByteBuffer src) throws IOException - { - assert src == buffer; - int size = src.position(); - doFlush(size); - return size; - } - - @Override - public boolean isOpen() - { - return channel.isOpen(); - } - - @Override - public void close() - { } - }; - } - - public static ByteBufDataOutputStreamPlus create(StreamSession session, Channel channel, int bufferSize) - { - ByteBuf buf = channel.alloc().directBuffer(bufferSize, bufferSize); - return new ByteBufDataOutputStreamPlus(session, channel, buf, bufferSize); - } - - /** - * Writes the incoming buffer directly to the backing {@link #channel}, without copying to the intermediate {@link #buffer}. - */ - public ChannelFuture writeToChannel(ByteBuf buf) throws IOException - { - doFlush(buffer.position()); - - int byteCount = buf.readableBytes(); - - if (!Uninterruptibles.tryAcquireUninterruptibly(channelRateLimiter, byteCount, 5, TimeUnit.MINUTES)) - throw new IOException(String.format("outbound channel was not writable. 
Failed to acquire sufficient permits %d", byteCount)); - - // the (possibly naive) assumption that we should always flush after each incoming buf - ChannelFuture channelFuture = channel.writeAndFlush(buf); - channelFuture.addListener(future -> handleBuffer(future, byteCount)); - return channelFuture; - } - - /** - * Writes the incoming buffer directly to the backing {@link #channel}, without copying to the intermediate {@link #buffer}. - * The incoming buffer will be automatically released when the netty channel invokes the listeners of success/failure to - * send the buffer. - */ - public ChannelFuture writeToChannel(ByteBuffer buffer) throws IOException - { - ChannelFuture channelFuture = writeToChannel(Unpooled.wrappedBuffer(buffer)); - channelFuture.addListener(future -> FileUtils.clean(buffer)); - return channelFuture; - } - - /** - * Writes all data in file channel to stream BUFFER_SIZE at a time. - * Closes file channel when done - * - * @param f - * @return number of bytes transferred - * @throws IOException - */ - public long writeToChannel(FileChannel f, StreamRateLimiter limiter) throws IOException - { - final long length = f.size(); - long bytesTransferred = 0; - - try - { - while (bytesTransferred < length) - { - int toRead = (int) Math.min(bufferSize, length - bytesTransferred); - NonClosingDefaultFileRegion fileRegion = new NonClosingDefaultFileRegion(f, bytesTransferred, toRead); - - if (!Uninterruptibles.tryAcquireUninterruptibly(channelRateLimiter, toRead, 5, TimeUnit.MINUTES)) - throw new IOException(String.format("outbound channel was not writable. 
Failed to acquire sufficient permits %d", toRead)); - - limiter.acquire(toRead); - - bytesTransferred += toRead; - final boolean shouldClose = (bytesTransferred == length); // this is the last buffer, can safely close channel - - channel.writeAndFlush(fileRegion).addListener(future -> { - handleBuffer(future, toRead); - - if ((shouldClose || !future.isSuccess()) && f.isOpen()) - f.close(); - }); - logger.trace("{} of {} (toRead {} cs {})", bytesTransferred, length, toRead, f.isOpen()); - } - - return bytesTransferred; - } catch (Exception e) - { - if (f.isOpen()) - f.close(); - - throw e; - } - } - - @Override - protected void doFlush(int count) throws IOException - { - // flush the current backing write buffer only if there's any pending data - if (buffer.position() > 0 && channel.isOpen()) - { - int byteCount = buffer.position(); - currentBuf.writerIndex(byteCount); - - if (!Uninterruptibles.tryAcquireUninterruptibly(channelRateLimiter, byteCount, 2, TimeUnit.MINUTES)) - throw new IOException(String.format("outbound channel was not writable. Failed to acquire sufficient permits %d", byteCount)); - - channel.writeAndFlush(currentBuf).addListener(future -> handleBuffer(future, byteCount)); - currentBuf = channel.alloc().directBuffer(bufferSize, bufferSize); - buffer = currentBuf.nioBuffer(0, bufferSize); - } - } - - /** - * Handles the result of publishing a buffer to the channel. - * - * Note: this will be executed on the event loop. - */ - private void handleBuffer(Future future, int bytesWritten) - { - channelRateLimiter.release(bytesWritten); - logger.trace("bytesWritten {} {} because {}", bytesWritten, (future.isSuccess() == true) ? "Succeeded" : "Failed", future.cause()); - if (!future.isSuccess() && channel.isOpen()) - session.onError(future.cause()); - } - - public ByteBufAllocator getAllocator() - { - return channel.alloc(); - } - - /** - * {@inheritDoc} - * - * Flush any last buffered (if the channel is open), and release any buffers. 
*Not* responsible for closing - * the netty channel as we might use it again for transferring more files. - * - * Note: should be called on the producer thread, not the netty event loop. - */ - @Override - public void close() throws IOException - { - doFlush(0); - if (currentBuf.refCnt() > 0) - currentBuf.release(); - currentBuf = null; - buffer = null; - } -} diff --git a/src/java/org/apache/cassandra/net/async/ChannelWriter.java b/src/java/org/apache/cassandra/net/async/ChannelWriter.java deleted file mode 100644 index e9847366c2e5..000000000000 --- a/src/java/org/apache/cassandra/net/async/ChannelWriter.java +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; - -import com.google.common.annotations.VisibleForTesting; - -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundBuffer; -import io.netty.channel.ChannelPromise; -import io.netty.channel.MessageSizeEstimator; -import io.netty.handler.timeout.IdleStateEvent; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; -import io.netty.util.concurrent.Future; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.utils.CoalescingStrategies; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; - -/** - * Represents a ready and post-handshake channel that can send outbound messages. This class groups a netty channel - * with any other channel-related information we track and, most importantly, handles the details on when the channel is flushed. - * - *

Flushing

- * - * We don't flush to the socket on every message as it's a bit of a performance drag (making the system call, copying - * the buffer, sending out a small packet). Thus, by waiting until we have a decent chunk of data (for some definition - * of 'decent'), we can achieve better efficiency and improved performance (yay!). - *

- * When to flush mainly depends on whether we use message coalescing or not (see {@link CoalescingStrategies}). - *

- * Note that the callback functions are invoked on the netty event loop, which is (in almost all cases) different - * from the thread that will be invoking {@link #write(QueuedMessage, boolean)}. - * - *

Flushing without coalescing

- * - * When no coalescing is in effect, we want to send new message "right away". However, as said above, flushing after - * every message would be particularly inefficient when there is lots of message in our sending queue, and so in - * practice we want to flush in 2 cases: - * 1) After any message if there is no pending message in the send queue. - * 2) When we've filled up or exceeded the netty outbound buffer (see {@link ChannelOutboundBuffer}) - *

- * The second part is relatively simple and handled generically in {@link MessageOutHandler#write(ChannelHandlerContext, Object, ChannelPromise)} [1]. - * The first part however is made a little more complicated by how netty's event loop executes. It is woken up by - * external callers to the channel invoking a flush, via either {@link Channel#flush} or one of the {@link Channel#writeAndFlush} - * methods [2]. So a plain {@link Channel#write} will only queue the message in the channel, and not wake up the event loop. - *

- * This means we don't want to simply call {@link Channel#write} as we want the message processed immediately. But we - * also don't want to flush on every message if there is more in the sending queue, so simply calling - * {@link Channel#writeAndFlush} isn't completely appropriate either. In practice, we handle this by calling - * {@link Channel#writeAndFlush} (so the netty event loop does wake up), but we override the flush behavior so - * it actually only flushes if there are no pending messages (see how {@link MessageOutHandler#flush} delegates the flushing - * decision back to this class through {@link #onTriggeredFlush}, and how {@link SimpleChannelWriter} makes this a no-op; - * instead {@link SimpleChannelWriter} flushes after any message if there are no more pending ones in - * {@link #onMessageProcessed}). - * - *

Flushing with coalescing

- * - * The goal of coalescing is to (artificially) delay the flushing of data in order to aggregate even more data before - * sending a group of packets out. So we don't want to flush after messages even if there is no pending messages in the - * sending queue, but we rather want to delegate the decision on when to flush to the {@link CoalescingStrategy}. In - * pratice, when coalescing is enabled we will flush in 2 cases: - * 1) When the coalescing strategies decides that we should. - * 2) When we've filled up or exceeded the netty outbound buffer ({@link ChannelOutboundBuffer}), exactly like in the - * no coalescing case. - *

- * The second part is handled exactly like in the no coalescing case, see above. - * The first part is handled by {@link CoalescingChannelWriter#write(QueuedMessage, boolean)}. Whenever a message is sent, we check - * if a flush has been already scheduled by the coalescing strategy. If one has, we're done, otherwise we ask the - * strategy when the next flush should happen and schedule one. - * - *

Message timeouts and retries

- * - * The main outward-facing method is {@link #write(QueuedMessage, boolean)}, where callers pass a - * {@link QueuedMessage}. If a message times out, as defined in {@link QueuedMessage#isTimedOut()}, - * the message listener {@link #handleMessageFuture(Future, QueuedMessage, boolean)} is invoked - * with the cause being a {@link ExpiredException}. The message is not retried and it is dropped on the floor. - *

- * If there is some {@link IOException} on the socket after the message has been written to the netty channel, - * the message listener {@link #handleMessageFuture(Future, QueuedMessage, boolean)} is invoked - * and 1) we check to see if the connection should be re-established, and 2) possibly createRetry the message. - * - *

Failures

- * - *

Failure to make progress sending bytes

- * If we are unable to make progress sending messages, we'll receive a netty notification - * ({@link IdleStateEvent}) at {@link MessageOutHandler#userEventTriggered(ChannelHandlerContext, Object)}. - * We then want to close the socket/channel, and purge any messages in {@link OutboundMessagingConnection#backlog} - * to try to free up memory as quickly as possible. Any messages in the netty pipeline will be marked as fail - * (as we close the channel), but {@link MessageOutHandler#userEventTriggered(ChannelHandlerContext, Object)} also - * sets a channel attribute, {@link #PURGE_MESSAGES_CHANNEL_ATTR} to true. This is essentially as special flag - * that we can look at in the promise handler code ({@link #handleMessageFuture(Future, QueuedMessage, boolean)}) - * to indicate that any backlog should be thrown away. - * - *

Notes

- * [1] For those desperately interested, and only after you've read the entire class-level doc: You can register a custom - * {@link MessageSizeEstimator} with a netty channel. When a message is written to the channel, it will check the - * message size, and if the max ({@link ChannelOutboundBuffer}) size will be exceeded, a task to signal the "channel - * writability changed" will be executed in the channel. That task, however, will wake up the event loop. - * Thus if coalescing is enabled, the event loop will wake up prematurely and process (and possibly flush!) the messages - * currently in the queue, thus defeating an aspect of coalescing. Hence, we're not using that feature of netty. - * [2]: The netty event loop is also woken up by it's internal timeout on the epoll_wait() system call. - */ -abstract class ChannelWriter -{ - /** - * A netty channel {@link Attribute} to indicate, when a channel is closed, any backlogged messages should be purged, - * as well. See the class-level documentation for more information. - */ - static final AttributeKey PURGE_MESSAGES_CHANNEL_ATTR = AttributeKey.newInstance("purgeMessages"); - - protected final Channel channel; - private volatile boolean closed; - - /** Number of currently pending messages on this channel. */ - final AtomicLong pendingMessageCount = new AtomicLong(0); - - /** - * A consuming function that handles the result of each message sent. - */ - private final Consumer messageResultConsumer; - - /** - * A reusable instance to avoid creating garbage on preciessing the result of every message sent. - * As we have the guarantee that the netty evet loop is single threaded, there should be no contention over this - * instance, as long as it (not it's state) is shared across threads. 
- */ - private final MessageResult messageResult = new MessageResult(); - - protected ChannelWriter(Channel channel, Consumer messageResultConsumer) - { - this.channel = channel; - this.messageResultConsumer = messageResultConsumer; - channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).set(false); - } - - /** - * Creates a new {@link ChannelWriter} using the (assumed properly connected) provided channel, and using coalescing - * based on the provided strategy. - */ - static ChannelWriter create(Channel channel, Consumer messageResultConsumer, Optional coalescingStrategy) - { - return coalescingStrategy.isPresent() - ? new CoalescingChannelWriter(channel, messageResultConsumer, coalescingStrategy.get()) - : new SimpleChannelWriter(channel, messageResultConsumer); - } - - /** - * Writes a message to this {@link ChannelWriter} if the channel is writable. - *

- * We always want to write to the channel *unless* it's not writable yet still open. - * If the channel is closed, the promise will be notifed as a fail (due to channel closed), - * and let the handler ({@link #handleMessageFuture(Future, QueuedMessage, boolean)}) - * do the reconnect magic/dance. Thus we simplify when to reconnect by not burdening the (concurrent) callers - * of this method, and instead keep it all in the future handler/event loop (which is single threaded). - * - * @param message the message to write/send. - * @param checkWritability a flag to indicate if the status of the channel should be checked before passing - * the message on to the {@link #channel}. - * @return true if the message was written to the channel; else, false. - */ - boolean write(QueuedMessage message, boolean checkWritability) - { - if ( (checkWritability && (channel.isWritable()) || !channel.isOpen()) || !checkWritability) - { - write0(message).addListener(f -> handleMessageFuture(f, message, true)); - return true; - } - return false; - } - - /** - * Handles the future of sending a particular message on this {@link ChannelWriter}. - *

- * Note: this is called from the netty event loop, so there is no race across multiple execution of this method. - */ - @VisibleForTesting - void handleMessageFuture(Future future, QueuedMessage msg, boolean allowReconnect) - { - messageResult.setAll(this, msg, future, allowReconnect); - messageResultConsumer.accept(messageResult); - messageResult.clearAll(); - } - - boolean shouldPurgeBacklog() - { - if (!channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).get()) - return false; - - channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).set(false); - return true; - } - - /** - * Writes a backlog of message to this {@link ChannelWriter}. This is mostly equivalent to calling - * {@link #write(QueuedMessage, boolean)} for every message of the provided backlog queue, but - * it ignores any coalescing, triggering a flush only once after all messages have been sent. - * - * @param backlog the backlog of message to send. - * @return the count of items written to the channel from the queue. - */ - int writeBacklog(Queue backlog, boolean allowReconnect) - { - int count = 0; - while (true) - { - if (!channel.isWritable()) - break; - - QueuedMessage msg = backlog.poll(); - if (msg == null) - break; - - pendingMessageCount.incrementAndGet(); - ChannelFuture future = channel.write(msg); - future.addListener(f -> handleMessageFuture(f, msg, allowReconnect)); - count++; - } - - // as this is an infrequent operation, don't bother coordinating with the instance-level flush task - if (count > 0) - channel.flush(); - - return count; - } - - void close() - { - if (closed) - return; - - closed = true; - channel.close(); - } - - long pendingMessageCount() - { - return pendingMessageCount.get(); - } - - /** - * Close the underlying channel but only after having make sure every pending message has been properly sent. 
- */ - void softClose() - { - if (closed) - return; - - closed = true; - channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); - } - - @VisibleForTesting - boolean isClosed() - { - return closed; - } - - /** - * Write the message to the {@link #channel}. - *

- * Note: this method, in almost all cases, is invoked from an app-level writing thread, not the netty event loop. - */ - protected abstract ChannelFuture write0(QueuedMessage message); - - /** - * Invoked after a message has been processed in the pipeline. Should only be used for essential bookkeeping operations. - *

- * Note: this method is invoked on the netty event loop. - */ - abstract void onMessageProcessed(ChannelHandlerContext ctx); - - /** - * Invoked when pipeline receives a flush request. - *

- * Note: this method is invoked on the netty event loop. - */ - abstract void onTriggeredFlush(ChannelHandlerContext ctx); - - /** - * Handles the non-coalescing flush case. - */ - @VisibleForTesting - static class SimpleChannelWriter extends ChannelWriter - { - private SimpleChannelWriter(Channel channel, Consumer messageResultConsumer) - { - super(channel, messageResultConsumer); - } - - protected ChannelFuture write0(QueuedMessage message) - { - pendingMessageCount.incrementAndGet(); - // We don't truly want to flush on every message but we do want to wake-up the netty event loop for the - // channel so the message is processed right away, which is why we use writeAndFlush. This won't actually - // flush, though, because onTriggeredFlush, which MessageOutHandler delegates to, does nothing. We will - // flush after the message is processed though if there is no pending one due to onMessageProcessed. - // See the class javadoc for context and much more details. - return channel.writeAndFlush(message); - } - - void onMessageProcessed(ChannelHandlerContext ctx) - { - if (pendingMessageCount.decrementAndGet() == 0) - ctx.flush(); - } - - void onTriggeredFlush(ChannelHandlerContext ctx) - { - // Don't actually flush on "normal" flush calls to the channel. - } - } - - /** - * Handles the coalescing flush case. 
- */ - @VisibleForTesting - static class CoalescingChannelWriter extends ChannelWriter - { - private static final int MIN_MESSAGES_FOR_COALESCE = DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages(); - - private final CoalescingStrategy strategy; - private final int minMessagesForCoalesce; - - @VisibleForTesting - final AtomicBoolean scheduledFlush = new AtomicBoolean(false); - - CoalescingChannelWriter(Channel channel, Consumer messageResultConsumer, CoalescingStrategy strategy) - { - this (channel, messageResultConsumer, strategy, MIN_MESSAGES_FOR_COALESCE); - } - - @VisibleForTesting - CoalescingChannelWriter(Channel channel, Consumer messageResultConsumer, CoalescingStrategy strategy, int minMessagesForCoalesce) - { - super(channel, messageResultConsumer); - this.strategy = strategy; - this.minMessagesForCoalesce = minMessagesForCoalesce; - } - - protected ChannelFuture write0(QueuedMessage message) - { - long pendingCount = pendingMessageCount.incrementAndGet(); - ChannelFuture future = channel.write(message); - strategy.newArrival(message); - - // if we lost the race to set the state, simply write to the channel (no flush) - if (!scheduledFlush.compareAndSet(false, true)) - return future; - - long flushDelayNanos; - // if we've hit the minimum number of messages for coalescing or we've run out of coalesce time, flush. - // note: we check the exact count, instead of greater than or equal to, of message here to prevent a flush task - // for each message (if there's messages coming in on multiple threads). There will be, of course, races - // with the consumer decrementing the pending counter, but that's still less excessive flushes. 
- if (pendingCount == minMessagesForCoalesce || (flushDelayNanos = strategy.currentCoalescingTimeNanos()) <= 0) - { - scheduledFlush.set(false); - channel.flush(); - } - else - { - // calling schedule() on the eventLoop will force it to wake up (if not already executing) and schedule the task - channel.eventLoop().schedule(() -> { - // NOTE: this executes on the event loop - scheduledFlush.set(false); - // we execute() the flush() as an additional task rather than immediately in-line as there is a - // race condition when this task runs (executing on the event loop) and a thread that writes the channel (top of this method). - // If this task is picked up but before the scheduledFlush falg is flipped, the other thread writes - // and then checks the scheduledFlush (which is still true) and exits. - // This task changes the flag and if it calls flush() in-line, and netty flushs everything immediately (that is, what's been serialized) - // to the transport as we're on the event loop. The other thread's write became a task that executes *after* this task in the netty queue, - // and if there's not a subsequent followup flush scheduled, that write can be orphaned until another write comes in. - channel.eventLoop().execute(channel::flush); - }, flushDelayNanos, TimeUnit.NANOSECONDS); - } - return future; - } - - void onMessageProcessed(ChannelHandlerContext ctx) - { - pendingMessageCount.decrementAndGet(); - } - - void onTriggeredFlush(ChannelHandlerContext ctx) - { - // When coalescing, obey the flush calls normally - ctx.flush(); - } - } -} diff --git a/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java b/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java deleted file mode 100644 index ebf26bddde4d..000000000000 --- a/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.net.async; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Objects; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.primitives.Ints; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufInputStream; -import io.netty.buffer.ByteBufOutputStream; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessagingService; - -/** - * Messages for the handshake phase of the internode protocol. - *

- * The handshake's main purpose is to establish a protocol version that both side can talk, as well as exchanging a few connection - * options/parameters. The handshake is composed of 3 messages, the first being sent by the initiator of the connection. The other - * side then answer with the 2nd message. At that point, if a version mismatch is detected by the connection initiator, - * it will simply disconnect and reconnect with a more appropriate version. But if the version is acceptable, the connection - * initiator sends the third message of the protocol, after which it considers the connection ready. - *

- * See below for a more precise description of each of those 3 messages. - *

- * Note that this handshake protocol doesn't fully apply to streaming. For streaming, only the first message is sent, - * after which the streaming protocol takes over (not documented here) - */ -public class HandshakeProtocol -{ - /** - * The initial message sent when a node creates a new connection to a remote peer. This message contains: - * 1) the {@link MessagingService#PROTOCOL_MAGIC} number (4 bytes). - * 2) the connection flags (4 bytes), which encodes: - * - the version the initiator thinks should be used for the connection (in practice, either the initiator - * version if it's the first time we connect to that remote since startup, or the last version known for that - * peer otherwise). - * - the "mode" of the connection: whether it is for streaming or for messaging. - * - whether compression should be used or not (if it is, compression is enabled _after_ the last message of the - * handshake has been sent). - *

- * More precisely, connection flags: - *

-     * {@code
-     *                      1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3
-     *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-     * |U U C M       |                |                               |
-     * |N N M O       |     VERSION    |             unused            |
-     * |U U P D       |                |                               |
-     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-     * }
-     * 
- * UNU - unused bits lowest two bits; from a historical note: used to be "serializer type," which was always Binary - * CMP - compression enabled bit - * MOD - connection mode. If the bit is on, the connection is for streaming; if the bit is off, it is for inter-node messaging. - * VERSION - if a streaming connection, indicates the streaming protocol version {@link org.apache.cassandra.streaming.messages.StreamMessage#CURRENT_VERSION}; - * if a messaging connection, indicates the messaging protocol version the initiator *thinks* should be used. - */ - public static class FirstHandshakeMessage - { - /** Contains the PROTOCOL_MAGIC (int) and the flags (int). */ - private static final int LENGTH = 8; - - final int messagingVersion; - final NettyFactory.Mode mode; - final boolean compressionEnabled; - - public FirstHandshakeMessage(int messagingVersion, NettyFactory.Mode mode, boolean compressionEnabled) - { - assert messagingVersion > 0; - this.messagingVersion = messagingVersion; - this.mode = mode; - this.compressionEnabled = compressionEnabled; - } - - @VisibleForTesting - int encodeFlags() - { - int flags = 0; - if (compressionEnabled) - flags |= 1 << 2; - if (mode == NettyFactory.Mode.STREAMING) - flags |= 1 << 3; - - flags |= (messagingVersion << 8); - return flags; - } - - public ByteBuf encode(ByteBufAllocator allocator) - { - ByteBuf buffer = allocator.directBuffer(LENGTH, LENGTH); - buffer.writerIndex(0); - buffer.writeInt(MessagingService.PROTOCOL_MAGIC); - buffer.writeInt(encodeFlags()); - return buffer; - } - - static FirstHandshakeMessage maybeDecode(ByteBuf in) throws IOException - { - if (in.readableBytes() < LENGTH) - return null; - - MessagingService.validateMagic(in.readInt()); - int flags = in.readInt(); - int version = MessagingService.getBits(flags, 15, 8); - NettyFactory.Mode mode = MessagingService.getBits(flags, 3, 1) == 1 - ? 
NettyFactory.Mode.STREAMING - : NettyFactory.Mode.MESSAGING; - boolean compressed = MessagingService.getBits(flags, 2, 1) == 1; - return new FirstHandshakeMessage(version, mode, compressed); - } - - @Override - public boolean equals(Object other) - { - if (!(other instanceof FirstHandshakeMessage)) - return false; - - FirstHandshakeMessage that = (FirstHandshakeMessage)other; - return this.messagingVersion == that.messagingVersion - && this.mode == that.mode - && this.compressionEnabled == that.compressionEnabled; - } - - @Override - public int hashCode() - { - return Objects.hash(messagingVersion, mode, compressionEnabled); - } - - @Override - public String toString() - { - return String.format("FirstHandshakeMessage - messaging version: %d, mode: %s, compress: %b", messagingVersion, mode, compressionEnabled); - } - } - - /** - * The second message of the handshake, sent by the node receiving the {@link FirstHandshakeMessage} back to the - * connection initiator. This message contains the messaging version of the peer sending this message, - * so {@link org.apache.cassandra.net.MessagingService#current_version}. - */ - static class SecondHandshakeMessage - { - /** The messaging version sent by the receiving peer (int). */ - private static final int LENGTH = 4; - - final int messagingVersion; - - SecondHandshakeMessage(int messagingVersion) - { - this.messagingVersion = messagingVersion; - } - - public ByteBuf encode(ByteBufAllocator allocator) - { - ByteBuf buffer = allocator.directBuffer(LENGTH, LENGTH); - buffer.writerIndex(0); - buffer.writeInt(messagingVersion); - return buffer; - } - - static SecondHandshakeMessage maybeDecode(ByteBuf in) - { - return in.readableBytes() >= LENGTH ? 
new SecondHandshakeMessage(in.readInt()) : null; - } - - @Override - public boolean equals(Object other) - { - return other instanceof SecondHandshakeMessage - && this.messagingVersion == ((SecondHandshakeMessage) other).messagingVersion; - } - - @Override - public int hashCode() - { - return Integer.hashCode(messagingVersion); - } - - @Override - public String toString() - { - return String.format("SecondHandshakeMessage - messaging version: %d", messagingVersion); - } - } - - /** - * The third message of the handshake, sent by the connection initiator on reception of {@link SecondHandshakeMessage}. - * This message contains: - * 1) the connection initiator's messaging version (4 bytes) - {@link org.apache.cassandra.net.MessagingService#current_version}. - * This indicates the max messaging version supported by this node. - * 2) the connection initiator's broadcast address as encoded by {@link org.apache.cassandra.net.CompactEndpointSerializationHelper}. - * This can be either 5 bytes for an IPv4 address, or 17 bytes for an IPv6 one. - *

- * This message concludes the handshake protocol. After that, the connection will used either for streaming, or to - * send messages. If the connection is to be compressed, compression is enabled only after this message is sent/received. - */ - static class ThirdHandshakeMessage - { - /** - * The third message contains the version and IP address of the sending node. Because the IP can be either IPv4 or - * IPv6, this can be either 9 (4 for version + 5 for IP) or 21 (4 for version + 17 for IP) bytes. Since we can't know - * a priori if the IP address will be v4 or v6, go with the minimum required bytes and hope that if the address is - * v6, we'll have the extra 12 bytes in the packet. - */ - private static final int MIN_LENGTH = 9; - - /** - * The internode messaging version of the peer; used for serializing to a version the peer understands. - */ - final int messagingVersion; - final InetAddressAndPort address; - - ThirdHandshakeMessage(int messagingVersion, InetAddressAndPort address) - { - this.messagingVersion = messagingVersion; - this.address = address; - } - - @SuppressWarnings("resource") - public ByteBuf encode(ByteBufAllocator allocator) - { - int bufLength = Ints.checkedCast(Integer.BYTES + CompactEndpointSerializationHelper.instance.serializedSize(address, messagingVersion)); - ByteBuf buffer = allocator.directBuffer(bufLength, bufLength); - buffer.writerIndex(0); - - // the max messaging version supported by the local node (not #messagingVersion) - buffer.writeInt(MessagingService.current_version); - try - { - DataOutputPlus dop = new ByteBufDataOutputPlus(buffer); - CompactEndpointSerializationHelper.instance.serialize(address, dop, messagingVersion); - return buffer; - } - catch (IOException e) - { - // Shouldn't happen, we're serializing in memory. 
- throw new AssertionError(e); - } - } - - @SuppressWarnings("resource") - static ThirdHandshakeMessage maybeDecode(ByteBuf in) - { - if (in.readableBytes() < MIN_LENGTH) - return null; - - in.markReaderIndex(); - int version = in.readInt(); - DataInputPlus input = new ByteBufDataInputPlus(in); - try - { - InetAddressAndPort address = CompactEndpointSerializationHelper.instance.deserialize(input, version); - return new ThirdHandshakeMessage(version, address); - } - catch (IOException e) - { - // makes the assumption we didn't have enough bytes to deserialize an IPv6 address, - // as we only check the MIN_LENGTH of the buf. - in.resetReaderIndex(); - return null; - } - } - - @Override - public boolean equals(Object other) - { - if (!(other instanceof ThirdHandshakeMessage)) - return false; - - ThirdHandshakeMessage that = (ThirdHandshakeMessage)other; - return this.messagingVersion == that.messagingVersion - && Objects.equals(this.address, that.address); - } - - @Override - public int hashCode() - { - return Objects.hash(messagingVersion, address); - } - - @Override - public String toString() - { - return String.format("ThirdHandshakeMessage - messaging version: %d, address = %s", messagingVersion, address); - } - } -} diff --git a/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java b/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java deleted file mode 100644 index e66a589c8a5c..000000000000 --- a/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java +++ /dev/null @@ -1,321 +0,0 @@ -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.List; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import javax.net.ssl.SSLSession; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import 
io.netty.buffer.ByteBuf; -import io.netty.channel.AdaptiveRecvByteBufAllocator; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPipeline; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.ssl.SslHandler; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; -import org.apache.cassandra.streaming.async.StreamingInboundHandler; -import org.apache.cassandra.streaming.messages.StreamMessage; - -/** - * 'Server'-side component that negotiates the internode handshake when establishing a new connection. - * This handler will be the first in the netty channel for each incoming connection (secure socket (TLS) notwithstanding), - * and once the handshake is successful, it will configure the proper handlers (mostly {@link MessageInHandler}) - * and remove itself from the working pipeline. - */ -class InboundHandshakeHandler extends ByteToMessageDecoder -{ - private static final Logger logger = LoggerFactory.getLogger(InboundHandshakeHandler.class); - - enum State { START, AWAITING_HANDSHAKE_BEGIN, AWAIT_MESSAGING_START_RESPONSE, HANDSHAKE_COMPLETE, HANDSHAKE_FAIL } - - private State state; - - private final IInternodeAuthenticator authenticator; - - private boolean hasAuthenticated; - /** - * The peer's declared messaging version. - */ - private int version; - - /** - * Does the peer support (or want to use) compressed data? 
- */ - private boolean compressed; - - /** - * A future the essentially places a timeout on how long we'll wait for the peer - * to complete the next step of the handshake. - */ - private Future handshakeTimeout; - - InboundHandshakeHandler(IInternodeAuthenticator authenticator) - { - this.authenticator = authenticator; - state = State.START; - } - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) - { - try - { - if (!hasAuthenticated) - { - logSecureSocketDetails(ctx); - if (!handleAuthenticate(ctx.channel().remoteAddress(), ctx)) - return; - } - - switch (state) - { - case START: - state = handleStart(ctx, in); - break; - case AWAIT_MESSAGING_START_RESPONSE: - state = handleMessagingStartResponse(ctx, in); - break; - case HANDSHAKE_FAIL: - throw new IllegalStateException("channel should be closed after determining the handshake failed with peer: " + ctx.channel().remoteAddress()); - default: - logger.error("unhandled state: {}", state); - state = State.HANDSHAKE_FAIL; - ctx.close(); - } - } - catch (Exception e) - { - logger.error("unexpected error while negotiating internode messaging handshake", e); - state = State.HANDSHAKE_FAIL; - ctx.close(); - } - } - - /** - * Ensure the peer is allowed to connect to this node. - */ - @VisibleForTesting - boolean handleAuthenticate(SocketAddress socketAddress, ChannelHandlerContext ctx) - { - // the only reason addr would not be instanceof InetSocketAddress is in unit testing, when netty's EmbeddedChannel - // uses EmbeddedSocketAddress. Normally, we'd do an instanceof for that class name, but it's marked with default visibility, - // so we can't reference it outside of it's package (and so it doesn't compile). 
- if (socketAddress instanceof InetSocketAddress) - { - InetSocketAddress addr = (InetSocketAddress)socketAddress; - if (!authenticator.authenticate(addr.getAddress(), addr.getPort())) - { - if (logger.isTraceEnabled()) - logger.trace("Failed to authenticate peer {}", addr); - ctx.close(); - return false; - } - } - else if (!socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress")) - { - ctx.close(); - return false; - } - hasAuthenticated = true; - return true; - } - - /** - * If the connection is using SSL/TLS, log some details about it. - */ - private void logSecureSocketDetails(ChannelHandlerContext ctx) - { - SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); - if (sslHandler != null) - { - SSLSession session = sslHandler.engine().getSession(); - logger.info("connection from peer {}, protocol = {}, cipher suite = {}", - ctx.channel().remoteAddress(), session.getProtocol(), session.getCipherSuite()); - } - } - - /** - * Handles receiving the first message in the internode messaging handshake protocol. If the sender's protocol version - * is accepted, we respond with the second message of the handshake protocol. - */ - @VisibleForTesting - State handleStart(ChannelHandlerContext ctx, ByteBuf in) throws IOException - { - FirstHandshakeMessage msg = FirstHandshakeMessage.maybeDecode(in); - if (msg == null) - return State.START; - - logger.trace("received first handshake message from peer {}, message = {}", ctx.channel().remoteAddress(), msg); - version = msg.messagingVersion; - - if (msg.mode == NettyFactory.Mode.STREAMING) - { - // streaming connections are per-session and have a fixed version. we can't do anything with a wrong-version stream connection, so drop it. - if (version != StreamMessage.CURRENT_VERSION) - { - logger.warn("Received stream using protocol version {} (my version {}). 
Terminating connection", version, MessagingService.current_version); - ctx.close(); - return State.HANDSHAKE_FAIL; - } - - setupStreamingPipeline(ctx, version); - return State.HANDSHAKE_COMPLETE; - } - else - { - if (version < MessagingService.VERSION_30) - { - logger.error("Unable to read obsolete message version {} from {}; The earliest version supported is 3.0.0", version, ctx.channel().remoteAddress()); - ctx.close(); - return State.HANDSHAKE_FAIL; - } - - logger.trace("Connection version {} from {}", version, ctx.channel().remoteAddress()); - compressed = msg.compressionEnabled; - - // if this version is < the MS version the other node is trying - // to connect with, the other node will disconnect - ctx.writeAndFlush(new SecondHandshakeMessage(MessagingService.current_version).encode(ctx.alloc())) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - - // outbound side will reconnect to change the version - if (version > MessagingService.current_version) - { - logger.info("peer wants to use a messaging version higher ({}) than what this node supports ({})", version, MessagingService.current_version); - ctx.close(); - return State.HANDSHAKE_FAIL; - } - - long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getRpcTimeout()); - handshakeTimeout = ctx.executor().schedule(() -> failHandshake(ctx), timeout, TimeUnit.MILLISECONDS); - return State.AWAIT_MESSAGING_START_RESPONSE; - } - } - - private void setupStreamingPipeline(ChannelHandlerContext ctx, int protocolVersion) - { - ChannelPipeline pipeline = ctx.pipeline(); - InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress(); - pipeline.addLast(NettyFactory.instance.streamingGroup, "streamInbound", new StreamingInboundHandler(InetAddressAndPort.getByAddressOverrideDefaults(address.getAddress(), address.getPort()), protocolVersion, null)); - pipeline.remove(this); - - // pass a custom recv ByteBuf allocator to the channel. 
the default recv ByteBuf size is 1k, but in streaming we're - // dealing with large bulk blocks of data, let's default to larger sizes - ctx.channel().config().setRecvByteBufAllocator(new AdaptiveRecvByteBufAllocator(1 << 8, 1 << 13, 1 << 16)); - } - - /** - * Handles the third (and last) message in the internode messaging handshake protocol. Grabs the protocol version and - * IP addr the peer wants to use. - */ - @VisibleForTesting - State handleMessagingStartResponse(ChannelHandlerContext ctx, ByteBuf in) throws IOException - { - ThirdHandshakeMessage msg = ThirdHandshakeMessage.maybeDecode(in); - if (msg == null) - return State.AWAIT_MESSAGING_START_RESPONSE; - - logger.trace("received third handshake message from peer {}, message = {}", ctx.channel().remoteAddress(), msg); - if (handshakeTimeout != null) - { - handshakeTimeout.cancel(false); - handshakeTimeout = null; - } - - int maxVersion = msg.messagingVersion; - if (maxVersion > MessagingService.current_version) - { - logger.error("peer wants to use a messaging version higher ({}) than what this node supports ({})", maxVersion, MessagingService.current_version); - ctx.close(); - return State.HANDSHAKE_FAIL; - } - - // record the (true) version of the endpoint - InetAddressAndPort from = msg.address; - MessagingService.instance().setVersion(from, maxVersion); - if (logger.isTraceEnabled()) - logger.trace("Set version for {} to {} (will use {})", from, maxVersion, MessagingService.instance().getVersion(from)); - - setupMessagingPipeline(ctx.pipeline(), from, compressed, version); - return State.HANDSHAKE_COMPLETE; - } - - @VisibleForTesting - void setupMessagingPipeline(ChannelPipeline pipeline, InetAddressAndPort peer, boolean compressed, int messagingVersion) - { - if (compressed) - pipeline.addLast(NettyFactory.INBOUND_COMPRESSOR_HANDLER_NAME, NettyFactory.createLz4Decoder(messagingVersion)); - - BaseMessageInHandler messageInHandler = messagingVersion >= MessagingService.VERSION_40 - ? 
new MessageInHandler(peer, messagingVersion) - : new MessageInHandlerPre40(peer, messagingVersion); - - pipeline.addLast("messageInHandler", messageInHandler); - pipeline.remove(this); - } - - @VisibleForTesting - void failHandshake(ChannelHandlerContext ctx) - { - // we're not really racing on the handshakeTimeout as we're in the event loop, - // but, hey, defensive programming is beautiful thing! - if (state == State.HANDSHAKE_COMPLETE || (handshakeTimeout != null && handshakeTimeout.isCancelled())) - return; - - state = State.HANDSHAKE_FAIL; - ctx.close(); - - if (handshakeTimeout != null) - { - handshakeTimeout.cancel(false); - handshakeTimeout = null; - } - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) - { - logger.trace("Failed to properly handshake with peer {}. Closing the channel.", ctx.channel().remoteAddress()); - failHandshake(ctx); - ctx.fireChannelInactive(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - logger.error("Failed to properly handshake with peer {}. Closing the channel.", ctx.channel().remoteAddress(), cause); - failHandshake(ctx); - } - - @VisibleForTesting - public State getState() - { - return state; - } - - @VisibleForTesting - public void setState(State nextState) - { - state = nextState; - } - - @VisibleForTesting - void setHandshakeTimeout(Future timeout) - { - handshakeTimeout = timeout; - } -} diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandler.java b/src/java/org/apache/cassandra/net/async/MessageInHandler.java deleted file mode 100644 index dafa9933331b..000000000000 --- a/src/java/org/apache/cassandra/net/async/MessageInHandler.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.BooleanSupplier; - -import com.google.common.primitives.Ints; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.utils.vint.VIntCoding; - -/** - * Parses incoming messages as per the 4.0 internode messaging protocol. 
- */ -public class MessageInHandler extends BaseMessageInHandler -{ - public static final Logger logger = LoggerFactory.getLogger(MessageInHandler.class); - - private MessageHeader messageHeader; - - MessageInHandler(InetAddressAndPort peer, int messagingVersion) - { - this (peer, messagingVersion, MESSAGING_SERVICE_CONSUMER); - } - - public MessageInHandler(InetAddressAndPort peer, int messagingVersion, BiConsumer messageConsumer) - { - super(peer, messagingVersion, messageConsumer); - - assert messagingVersion >= MessagingService.VERSION_40 : String.format("wrong messaging version for this handler: got %d, but expect %d or higher", - messagingVersion, MessagingService.VERSION_40); - state = State.READ_FIRST_CHUNK; - } - - /** - * For each new message coming in, builds up a {@link MessageHeader} instance incrementally. This method - * attempts to deserialize as much header information as it can out of the incoming {@link ByteBuf}, and - * maintains a trivial state machine to remember progress across invocations. - */ - @SuppressWarnings("resource") - public void handleDecode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception - { - ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in); - while (true) - { - switch (state) - { - case READ_FIRST_CHUNK: - MessageHeader header = readFirstChunk(in); - if (header == null) - return; - header.from = peer; - messageHeader = header; - state = State.READ_VERB; - // fall-through - case READ_VERB: - if (in.readableBytes() < VERB_LENGTH) - return; - messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); - state = State.READ_PARAMETERS_SIZE; - // fall-through - case READ_PARAMETERS_SIZE: - long length = VIntCoding.readUnsignedVInt(in); - if (length < 0) - return; - messageHeader.parameterLength = (int) length; - messageHeader.parameters = messageHeader.parameterLength == 0 ? 
Collections.emptyMap() : new EnumMap<>(ParameterType.class); - state = State.READ_PARAMETERS_DATA; - // fall-through - case READ_PARAMETERS_DATA: - if (messageHeader.parameterLength > 0) - { - if (in.readableBytes() < messageHeader.parameterLength) - return; - readParameters(in, inputPlus, messagingVersion, messageHeader.parameterLength, messageHeader.parameters); - } - state = State.READ_PAYLOAD_SIZE; - // fall-through - case READ_PAYLOAD_SIZE: - length = VIntCoding.readUnsignedVInt(in); - if (length < 0) - return; - messageHeader.payloadSize = (int) length; - state = State.READ_PAYLOAD; - // fall-through - case READ_PAYLOAD: - if (in.readableBytes() < messageHeader.payloadSize) - return; - - // TODO consider deserializing the message not on the event loop - MessageIn messageIn = MessageIn.read(inputPlus, messagingVersion, - messageHeader.messageId, messageHeader.constructionTime, messageHeader.from, - messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters); - - if (messageIn != null) - messageConsumer.accept(messageIn, messageHeader.messageId); - - state = State.READ_FIRST_CHUNK; - messageHeader = null; - break; - default: - throw new IllegalStateException("unknown/unhandled state: " + state); - } - } - } - - private static void readParameters(ByteBuf buf, DataInputPlus in, int messagingVersion, int parameterLength, Map parameters) throws IOException - { - // makes the assumption we have all the bytes required to read the headers - final int endIndex = buf.readerIndex() + parameterLength; - while (buf.readerIndex() < endIndex) - { - String key = in.readUTF(); - ParameterType parameterType = ParameterType.byName.get(key); - long valueLength = in.readUnsignedVInt(); - byte[] value = new byte[Ints.checkedCast(valueLength)]; - in.readFully(value); - try (DataInputBuffer buffer = new DataInputBuffer(value)) - { - parameters.put(parameterType, parameterType.serializer.deserialize(buffer, messagingVersion)); - } - } - } - - private static void 
readParameters(BooleanSupplier isDone, DataInputPlus in, int messagingVersion, Map parameters) throws IOException - { - // makes the assumption we have all the bytes required to read the headers - while (!isDone.getAsBoolean()) - { - String key = in.readUTF(); - ParameterType parameterType = ParameterType.byName.get(key); - in.readUnsignedVInt(); - parameters.put(parameterType, parameterType.serializer.deserialize(in, messagingVersion)); - } - } - - public static MessageIn deserialize(DataInputPlus in, int id, int version, InetAddressAndPort from) throws IOException - { - if (version >= MessagingService.VERSION_40) - return deserialize40(in, id, version, from); - else - return MessageInHandlerPre40.deserializePre40(in, id, version, from); - } - - private static MessageIn deserialize40(DataInputPlus in, int id, int version, InetAddressAndPort from) throws IOException - { - MessagingService.Verb verb = MessagingService.Verb.fromId(in.readInt()); - - Map parameters = Collections.emptyMap(); - int parameterLength = (int) in.readUnsignedVInt(); - if (parameterLength != 0) - { - parameters = new EnumMap<>(ParameterType.class); - byte[] bytes = new byte[parameterLength]; - in.readFully(bytes); - try (DataInputBuffer buffer = new DataInputBuffer(bytes)) - { - readParameters(() -> buffer.available() == 0, buffer, version, parameters); - } - } - - Object payload = null; - int payloadSize = (int) in.readUnsignedVInt(); - if (payloadSize > 0) - { - IVersionedSerializer serializer = MessagingService.getVerbSerializer(verb, id); - if (serializer == null) in.skipBytesFully(payloadSize); - else payload = serializer.deserialize(in, version); - } - - return new MessageIn<>(from, payload, parameters, verb, version, System.nanoTime()); - } - - @Override - MessageHeader getMessageHeader() - { - return messageHeader; - } -} diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java b/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java deleted file 
mode 100644 index 6eeeea7dd2e2..000000000000 --- a/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.DataInputStream; -import java.io.IOException; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.BooleanSupplier; - -import com.google.common.annotations.VisibleForTesting; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; - -/** - * Parses incoming messages as per the pre-4.0 internode messaging protocol. 
- */ -public class MessageInHandlerPre40 extends BaseMessageInHandler -{ - public static final Logger logger = LoggerFactory.getLogger(MessageInHandlerPre40.class); - - static final int PARAMETERS_SIZE_LENGTH = Integer.BYTES; - static final int PARAMETERS_VALUE_SIZE_LENGTH = Integer.BYTES; - static final int PAYLOAD_SIZE_LENGTH = Integer.BYTES; - - private MessageHeader messageHeader; - - MessageInHandlerPre40(InetAddressAndPort peer, int messagingVersion) - { - this (peer, messagingVersion, MESSAGING_SERVICE_CONSUMER); - } - - public MessageInHandlerPre40(InetAddressAndPort peer, int messagingVersion, BiConsumer messageConsumer) - { - super(peer, messagingVersion, messageConsumer); - - assert messagingVersion < MessagingService.VERSION_40 : String.format("wrong messaging version for this handler: got %d, but expect lower than %d", - messagingVersion, MessagingService.VERSION_40); - state = State.READ_FIRST_CHUNK; - } - - /** - * For each new message coming in, builds up a {@link MessageHeader} instance incrementally. This method - * attempts to deserialize as much header information as it can out of the incoming {@link ByteBuf}, and - * maintains a trivial state machine to remember progress across invocations. - */ - @SuppressWarnings("resource") - public void handleDecode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception - { - ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in); - while (true) - { - switch (state) - { - case READ_FIRST_CHUNK: - MessageHeader header = readFirstChunk(in); - if (header == null) - return; - messageHeader = header; - state = State.READ_IP_ADDRESS; - // fall-through - case READ_IP_ADDRESS: - // unfortunately, this assumes knowledge of how CompactEndpointSerializationHelper serializes data (the first byte is the size). - // first, check that we can actually read the size byte, then check if we can read that number of bytes. 
- // the "+ 1" is to make sure we have the size byte in addition to the serialized IP addr count of bytes in the buffer. - int readableBytes = in.readableBytes(); - if (readableBytes < 1 || readableBytes < in.getByte(in.readerIndex()) + 1) - return; - messageHeader.from = CompactEndpointSerializationHelper.instance.deserialize(inputPlus, messagingVersion); - state = State.READ_VERB; - // fall-through - case READ_VERB: - if (in.readableBytes() < VERB_LENGTH) - return; - messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); - state = State.READ_PARAMETERS_SIZE; - // fall-through - case READ_PARAMETERS_SIZE: - if (in.readableBytes() < PARAMETERS_SIZE_LENGTH) - return; - messageHeader.parameterLength = in.readInt(); - messageHeader.parameters = messageHeader.parameterLength == 0 ? Collections.emptyMap() : new EnumMap<>(ParameterType.class); - state = State.READ_PARAMETERS_DATA; - // fall-through - case READ_PARAMETERS_DATA: - if (messageHeader.parameterLength > 0) - { - if (!readParameters(in, inputPlus, messageHeader.parameterLength, messageHeader.parameters)) - return; - } - state = State.READ_PAYLOAD_SIZE; - // fall-through - case READ_PAYLOAD_SIZE: - if (in.readableBytes() < PAYLOAD_SIZE_LENGTH) - return; - messageHeader.payloadSize = in.readInt(); - state = State.READ_PAYLOAD; - // fall-through - case READ_PAYLOAD: - if (in.readableBytes() < messageHeader.payloadSize) - return; - - // TODO consider deserailizing the messge not on the event loop - MessageIn messageIn = MessageIn.read(inputPlus, messagingVersion, - messageHeader.messageId, messageHeader.constructionTime, messageHeader.from, - messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters); - - if (messageIn != null) - messageConsumer.accept(messageIn, messageHeader.messageId); - - state = State.READ_FIRST_CHUNK; - messageHeader = null; - break; - default: - throw new IllegalStateException("unknown/unhandled state: " + state); - } - } - } - - /** - * @return true if all the 
parameters have been read from the {@link ByteBuf}; else, false. - */ - private boolean readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int parameterCount, Map parameters) throws IOException - { - // makes the assumption that map.size() is a constant time function (HashMap.size() is) - while (parameters.size() < parameterCount) - { - if (!canReadNextParam(in)) - return false; - - String key = inputPlus.readUTF(); - ParameterType parameterType = ParameterType.byName.get(key); - byte[] value = new byte[inputPlus.readInt()]; - inputPlus.readFully(value); - try (DataInputBuffer buffer = new DataInputBuffer(value)) - { - parameters.put(parameterType, parameterType.serializer.deserialize(buffer, messagingVersion)); - } - } - - return true; - } - - private static boolean readParameters(DataInputPlus in, int messagingVersion, int parameterCount, Map parameters) throws IOException - { - // makes the assumption that map.size() is a constant time function (HashMap.size() is) - while (parameters.size() < parameterCount) - { - String key = in.readUTF(); - ParameterType parameterType = ParameterType.byName.get(key); - in.readInt(); - parameters.put(parameterType, parameterType.serializer.deserialize(in, messagingVersion)); - } - - return true; - } - - static MessageIn deserializePre40(DataInputPlus in, int id, int version, InetAddressAndPort from) throws IOException - { - assert from.equals(CompactEndpointSerializationHelper.instance.deserialize(in, version)); - MessagingService.Verb verb = MessagingService.Verb.fromId(in.readInt()); - - Map parameters = Collections.emptyMap(); - int parameterCount = in.readInt(); - if (parameterCount != 0) - { - parameters = new EnumMap<>(ParameterType.class); - readParameters(in, version, parameterCount, parameters); - } - - Object payload = null; - int payloadSize = in.readInt(); - if (payloadSize > 0) - { - IVersionedSerializer serializer = MessagingService.getVerbSerializer(verb, id); - if (serializer == null) 
in.skipBytesFully(payloadSize); - else payload = serializer.deserialize(in, version); - } - - return new MessageIn<>(from, payload, parameters, verb, version, System.nanoTime()); - } - - - - /** - * Determine if we can read the next parameter from the {@link ByteBuf}. This method will *always* set the {@code in} - * readIndex back to where it was when this method was invoked. - * - * NOTE: this function would be sooo much simpler if we included a parameters length int in the messaging format, - * instead of checking the remaining readable bytes for each field as we're parsing it. c'est la vie ... - */ - @VisibleForTesting - static boolean canReadNextParam(ByteBuf in) - { - in.markReaderIndex(); - // capture the readableBytes value here to avoid all the virtual function calls. - // subtract 6 as we know we'll be reading a short and an int (for the utf and value lengths). - final int minimumBytesRequired = 6; - int readableBytes = in.readableBytes() - minimumBytesRequired; - if (readableBytes < 0) - return false; - - // this is a tad invasive, but since we know the UTF string is prefaced with a 2-byte length, - // read that to make sure we have enough bytes to read the string itself. 
- short strLen = in.readShort(); - // check if we can read that many bytes for the UTF - if (strLen > readableBytes) - { - in.resetReaderIndex(); - return false; - } - in.skipBytes(strLen); - readableBytes -= strLen; - - // check if we can read the value length - if (readableBytes < PARAMETERS_VALUE_SIZE_LENGTH) - { - in.resetReaderIndex(); - return false; - } - int valueLength = in.readInt(); - // check if we read that many bytes for the value - if (valueLength > readableBytes) - { - in.resetReaderIndex(); - return false; - } - - in.resetReaderIndex(); - return true; - } - - - @Override - MessageHeader getMessageHeader() - { - return messageHeader; - } -} diff --git a/src/java/org/apache/cassandra/net/async/MessageOutHandler.java b/src/java/org/apache/cassandra/net/async/MessageOutHandler.java deleted file mode 100644 index f1647ab03809..000000000000 --- a/src/java/org/apache/cassandra/net/async/MessageOutHandler.java +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.UUID; -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundBuffer; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.UnsupportedMessageTypeException; -import io.netty.handler.timeout.IdleState; -import io.netty.handler.timeout.IdleStateEvent; - -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.tracing.TraceState; -import org.apache.cassandra.tracing.Tracing; -import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; -import org.apache.cassandra.utils.NoSpamLogger; -import org.apache.cassandra.utils.UUIDGen; - -import static org.apache.cassandra.config.Config.PROPERTY_PREFIX; - -/** - * A Netty {@link ChannelHandler} for serializing outbound messages. - *

- * On top of transforming a {@link QueuedMessage} into bytes, this handler also feeds back progress to the linked - * {@link ChannelWriter} so that the latter can take decision on when data should be flushed (with and without coalescing). - * See the javadoc on {@link ChannelWriter} for more details about the callbacks as well as message timeouts. - *

- * Note: this class derives from {@link ChannelDuplexHandler} so we can intercept calls to - * {@link #userEventTriggered(ChannelHandlerContext, Object)} and {@link #channelWritabilityChanged(ChannelHandlerContext)}. - */ -class MessageOutHandler extends ChannelDuplexHandler -{ - private static final Logger logger = LoggerFactory.getLogger(MessageOutHandler.class); - private static final NoSpamLogger errorLogger = NoSpamLogger.getLogger(logger, 1, TimeUnit.SECONDS); - - /** - * The default size threshold for deciding when to auto-flush the channel. - */ - private static final int DEFAULT_AUTO_FLUSH_THRESHOLD = 1 << 16; - - // reatining the pre 4.0 property name for backward compatibility. - private static final String AUTO_FLUSH_PROPERTY = PROPERTY_PREFIX + "otc_buffer_size"; - static final int AUTO_FLUSH_THRESHOLD = Integer.getInteger(AUTO_FLUSH_PROPERTY, DEFAULT_AUTO_FLUSH_THRESHOLD); - - /** - * The amount of prefix data, in bytes, before the serialized message. - */ - private static final int MESSAGE_PREFIX_SIZE = 12; - - private final OutboundConnectionIdentifier connectionId; - - /** - * The version of the messaging protocol we're communicating at. - */ - private final int targetMessagingVersion; - - /** - * The minumum size at which we'll automatically flush the channel. 
- */ - private final int flushSizeThreshold; - - private final ChannelWriter channelWriter; - - private final Supplier backlogSupplier; - - MessageOutHandler(OutboundConnectionIdentifier connectionId, int targetMessagingVersion, ChannelWriter channelWriter, Supplier backlogSupplier) - { - this (connectionId, targetMessagingVersion, channelWriter, backlogSupplier, AUTO_FLUSH_THRESHOLD); - } - - MessageOutHandler(OutboundConnectionIdentifier connectionId, int targetMessagingVersion, ChannelWriter channelWriter, Supplier backlogSupplier, int flushThreshold) - { - this.connectionId = connectionId; - this.targetMessagingVersion = targetMessagingVersion; - this.channelWriter = channelWriter; - this.flushSizeThreshold = flushThreshold; - this.backlogSupplier = backlogSupplier; - } - - @Override - public void write(ChannelHandlerContext ctx, Object o, ChannelPromise promise) - { - // this is a temporary fix until https://github.com/netty/netty/pull/6867 is released (probably netty 4.1.13). - // TL;DR a closed channel can still process messages in the pipeline that were queued before the close. - // the channel handlers are removed from the channel potentially saync from the close operation. - if (!ctx.channel().isOpen()) - { - logger.debug("attempting to process a message in the pipeline, but channel {} is closed", ctx.channel().id()); - return; - } - - ByteBuf out = null; - try - { - if (!isMessageValid(o, promise)) - return; - - QueuedMessage msg = (QueuedMessage) o; - - // frame size includes the magic and and other values *before* the actual serialized message. 
- // note: don't even bother to check the compressed size (if compression is enabled for the channel), - // cuz if it's this large already, we're probably screwed anyway - long currentFrameSize = MESSAGE_PREFIX_SIZE + msg.message.serializedSize(targetMessagingVersion); - if (currentFrameSize > Integer.MAX_VALUE || currentFrameSize < 0) - { - promise.tryFailure(new IllegalStateException(String.format("%s illegal frame size: %d, ignoring message", connectionId, currentFrameSize))); - return; - } - - out = ctx.alloc().ioBuffer((int)currentFrameSize); - - captureTracingInfo(msg); - serializeMessage(msg, out); - ctx.write(out, promise); - - // check to see if we should flush based on buffered size - ChannelOutboundBuffer outboundBuffer = ctx.channel().unsafe().outboundBuffer(); - if (outboundBuffer != null && outboundBuffer.totalPendingWriteBytes() >= flushSizeThreshold) - ctx.flush(); - } - catch(Exception e) - { - if (out != null && out.refCnt() > 0) - out.release(out.refCnt()); - exceptionCaught(ctx, e); - promise.tryFailure(e); - } - finally - { - // Make sure we signal the outChanel even in case of errors. - channelWriter.onMessageProcessed(ctx); - } - } - - /** - * Test to see if the message passed in is a {@link QueuedMessage} and if it has timed out or not. If the checks fail, - * this method has the side effect of modifying the {@link ChannelPromise}. - */ - boolean isMessageValid(Object o, ChannelPromise promise) - { - // optimize for the common case - if (o instanceof QueuedMessage) - { - if (!((QueuedMessage)o).isTimedOut()) - { - return true; - } - else - { - promise.tryFailure(ExpiredException.INSTANCE); - } - } - else - { - promise.tryFailure(new UnsupportedMessageTypeException(connectionId + - " msg must be an instance of " + QueuedMessage.class.getSimpleName())); - } - return false; - } - - /** - * Record any tracing data, if enabled on this message. 
- */ - @VisibleForTesting - void captureTracingInfo(QueuedMessage msg) - { - try - { - UUID sessionId = (UUID)msg.message.getParameter(ParameterType.TRACE_SESSION); - if (sessionId != null) - { - TraceState state = Tracing.instance.get(sessionId); - String message = String.format("Sending %s message to %s, size = %d bytes", - msg.message.verb, connectionId.connectionAddress(), - msg.message.serializedSize(targetMessagingVersion) + MESSAGE_PREFIX_SIZE); - // session may have already finished; see CASSANDRA-5668 - if (state == null) - { - Tracing.TraceType traceType = (Tracing.TraceType)msg.message.getParameter(ParameterType.TRACE_TYPE); - traceType = traceType == null ? Tracing.TraceType.QUERY : traceType; - Tracing.instance.trace(ByteBuffer.wrap(UUIDGen.decompose(sessionId)), message, traceType.getTTL()); - } - else - { - state.trace(message); - if (msg.message.verb == MessagingService.Verb.REQUEST_RESPONSE) - Tracing.instance.doneWithNonLocalSession(state); - } - } - } - catch (Exception e) - { - logger.warn("{} failed to capture the tracing info for an outbound message, ignoring", connectionId, e); - } - } - - private void serializeMessage(QueuedMessage msg, ByteBuf out) throws IOException - { - out.writeInt(MessagingService.PROTOCOL_MAGIC); - out.writeInt(msg.id); - - // int cast cuts off the high-order half of the timestamp, which we can assume remains - // the same between now and when the recipient reconstructs it. - out.writeInt((int) NanoTimeToCurrentTimeMillis.convert(msg.timestampNanos)); - @SuppressWarnings("resource") - DataOutputPlus outStream = new ByteBufDataOutputPlus(out); - msg.message.serialize(outStream, targetMessagingVersion); - - // next few lines are for debugging ... massively helpful!! - // if we allocated too much buffer for this message, we'll log here. 
- // if we allocated to little buffer space, we would have hit an exception when trying to write more bytes to it - if (out.isWritable()) - errorLogger.error("{} reported message size {}, actual message size {}, msg {}", - connectionId, out.capacity(), out.writerIndex(), msg.message); - } - - @Override - public void flush(ChannelHandlerContext ctx) - { - channelWriter.onTriggeredFlush(ctx); - } - - - /** - * {@inheritDoc} - * - * When the channel becomes writable (assuming it was previously unwritable), try to eat through any backlogged messages - * {@link #backlogSupplier}. As we're on the event loop when this is invoked, no one else can fill up the netty - * {@link ChannelOutboundBuffer}, so we should be able to make decent progress chewing through the backlog - * (assuming not large messages). Any messages messages written from {@link OutboundMessagingConnection} threads won't - * be processed immediately; they'll be queued up as tasks, and once this function return, those messages can begin - * to be consumed. - *

- * Note: this is invoked on the netty event loop. - */ - @Override - public void channelWritabilityChanged(ChannelHandlerContext ctx) - { - if (!ctx.channel().isWritable()) - return; - - // guarantee at least a minimal amount of progress (one messge from the backlog) by using a do-while loop. - do - { - QueuedMessage msg = backlogSupplier.get(); - if (msg == null || !channelWriter.write(msg, false)) - break; - } while (ctx.channel().isWritable()); - } - - /** - * {@inheritDoc} - * - * If we get an {@link IdleStateEvent} for the write path, we want to close the channel as we can't make progress. - * That assumes, of course, that there's any outstanding bytes in the channel to write. We don't necesarrily care - * about idleness (for example, gossip channels will be idle most of the time), but instead our concern is - * the ability to make progress when there's work to be done. - *

- * Note: this is invoked on the netty event loop. - */ - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) - { - if (evt instanceof IdleStateEvent && ((IdleStateEvent)evt).state() == IdleState.WRITER_IDLE) - { - ChannelOutboundBuffer cob = ctx.channel().unsafe().outboundBuffer(); - if (cob != null && cob.totalPendingWriteBytes() > 0) - { - ctx.channel().attr(ChannelWriter.PURGE_MESSAGES_CHANNEL_ATTR) - .compareAndSet(Boolean.FALSE, Boolean.TRUE); - ctx.close(); - } - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - if (cause instanceof IOException) - logger.trace("{} io error", connectionId, cause); - else - logger.warn("{} error", connectionId, cause); - - ctx.close(); - } - - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) - { - ctx.flush(); - ctx.close(promise); - } -} diff --git a/src/java/org/apache/cassandra/net/async/MessageResult.java b/src/java/org/apache/cassandra/net/async/MessageResult.java deleted file mode 100644 index b0dc4dce1c66..000000000000 --- a/src/java/org/apache/cassandra/net/async/MessageResult.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import io.netty.util.concurrent.Future; - -/** - * A simple, reusable struct that holds the unprocessed result of sending a message via netty. This object is intended - * to be reusable to avoid creating a bunch of garbage (just for processing the results of sending a message). - * - * The intended use is to be a member field in a class, like {@link ChannelWriter}, repopulated on each message result, - * and then immediately cleared (via {@link #clearAll()}) when done. - */ -public class MessageResult -{ - ChannelWriter writer; - QueuedMessage msg; - Future future; - boolean allowReconnect; - - void setAll(ChannelWriter writer, QueuedMessage msg, Future future, boolean allowReconnect) - { - this.writer = writer; - this.msg = msg; - this.future = future; - this.allowReconnect = allowReconnect; - } - - void clearAll() - { - this.writer = null; - this.msg = null; - this.future = null; - } -} diff --git a/src/java/org/apache/cassandra/net/async/NettyFactory.java b/src/java/org/apache/cassandra/net/async/NettyFactory.java deleted file mode 100644 index 346a0672fe6c..000000000000 --- a/src/java/org/apache/cassandra/net/async/NettyFactory.java +++ /dev/null @@ -1,418 +0,0 @@ -package org.apache.cassandra.net.async; - -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.util.concurrent.TimeUnit; -import java.util.zip.Checksum; - -import javax.annotation.Nullable; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLParameters; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.bootstrap.Bootstrap; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.ChannelOption; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.EventLoopGroup; -import 
io.netty.channel.ServerChannel; -import io.netty.channel.epoll.EpollChannelOption; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.epoll.EpollServerSocketChannel; -import io.netty.channel.epoll.EpollSocketChannel; -import io.netty.channel.group.ChannelGroup; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.compression.Lz4FrameDecoder; -import io.netty.handler.codec.compression.Lz4FrameEncoder; -import io.netty.handler.logging.LogLevel; -import io.netty.handler.logging.LoggingHandler; -import io.netty.handler.ssl.OpenSsl; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslHandler; -import io.netty.util.concurrent.DefaultEventExecutor; -import io.netty.util.concurrent.DefaultThreadFactory; -import io.netty.util.concurrent.EventExecutor; -import io.netty.util.internal.logging.InternalLoggerFactory; -import io.netty.util.internal.logging.Slf4JLoggerFactory; - -import net.jpountz.lz4.LZ4Factory; -import net.jpountz.xxhash.XXHashFactory; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.security.SSLFactory; -import org.apache.cassandra.service.NativeTransportService; -import org.apache.cassandra.utils.ChecksumType; -import org.apache.cassandra.utils.CoalescingStrategies; -import org.apache.cassandra.utils.FBUtilities; - -/** - * A factory for building Netty {@link Channel}s. 
Channels here are setup with a pipeline to participate - * in the internode protocol handshake, either the inbound or outbound side as per the method invoked. - */ -public final class NettyFactory -{ - private static final Logger logger = LoggerFactory.getLogger(NettyFactory.class); - - /** - * The block size for use with netty's lz4 code. - */ - private static final int COMPRESSION_BLOCK_SIZE = 1 << 16; - - private static final int LZ4_HASH_SEED = 0x9747b28c; - - public enum Mode { MESSAGING, STREAMING } - - static final String SSL_CHANNEL_HANDLER_NAME = "ssl"; - private static final String OPTIONAL_SSL_CHANNEL_HANDLER_NAME = "optionalSsl"; - static final String INBOUND_COMPRESSOR_HANDLER_NAME = "inboundCompressor"; - static final String OUTBOUND_COMPRESSOR_HANDLER_NAME = "outboundCompressor"; - private static final String HANDSHAKE_HANDLER_NAME = "handshakeHandler"; - public static final String INBOUND_STREAM_HANDLER_NAME = "inboundStreamHandler"; - - /** a useful addition for debugging; simply set to true to get more data in your logs */ - private static final boolean WIRETRACE = false; - static - { - if (WIRETRACE) - InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); - } - - private static final boolean DEFAULT_USE_EPOLL = NativeTransportService.useEpoll(); - - /** - * The size of the receive queue for the outbound channels. As outbound channels do not receive data - * (outside of the internode messaging protocol's handshake), this value can be relatively small. - */ - private static final int OUTBOUND_CHANNEL_RECEIVE_BUFFER_SIZE = 1 << 10; - - /** - * The size of the send queue for the inbound channels. As inbound channels do not send data - * (outside of the internode messaging protocol's handshake), this value can be relatively small. - */ - private static final int INBOUND_CHANNEL_SEND_BUFFER_SIZE = 1 << 10; - - /** - * A factory instance that all normal, runtime code should use. Separate instances should only be used for testing. 
- */ - public static final NettyFactory instance = new NettyFactory(DEFAULT_USE_EPOLL); - - private final boolean useEpoll; - private final EventLoopGroup acceptGroup; - - private final EventLoopGroup inboundGroup; - private final EventLoopGroup outboundGroup; - public final EventLoopGroup streamingGroup; - - /** - * Constructor that allows modifying the {@link NettyFactory#useEpoll} for testing purposes. Otherwise, use the - * default {@link #instance}. - */ - @VisibleForTesting - NettyFactory(boolean useEpoll) - { - this.useEpoll = useEpoll; - acceptGroup = getEventLoopGroup(useEpoll, determineAcceptGroupSize(DatabaseDescriptor.getInternodeMessagingEncyptionOptions()), - "MessagingService-NettyAcceptor-Thread", false); - inboundGroup = getEventLoopGroup(useEpoll, FBUtilities.getAvailableProcessors(), "MessagingService-NettyInbound-Thread", false); - outboundGroup = getEventLoopGroup(useEpoll, FBUtilities.getAvailableProcessors(), "MessagingService-NettyOutbound-Thread", true); - streamingGroup = getEventLoopGroup(useEpoll, FBUtilities.getAvailableProcessors(), "Streaming-Netty-Thread", false); - } - - /** - * Determine the number of accept threads we need, which is based upon the number of listening sockets we will have. - * The idea is one accept thread per listening socket. - */ - public static int determineAcceptGroupSize(ServerEncryptionOptions serverEncryptionOptions) - { - int listenSocketCount = 1; - - boolean listenOnBroadcastAddr = MessagingService.shouldListenOnBroadcastAddress(); - if (listenOnBroadcastAddr) - listenSocketCount++; - - if (serverEncryptionOptions.enable_legacy_ssl_storage_port) - { - listenSocketCount++; - - if (listenOnBroadcastAddr) - listenSocketCount++; - } - - return listenSocketCount; - } - - /** - * Create an {@link EventLoopGroup}, for epoll or nio. The {@code boostIoRatio} flag passes a hint to the netty - * event loop threads to optimize comsuming all the tasks from the netty channel before checking for IO activity. 
- * By default, netty will process some maximum number of tasks off it's queue before it will check for activity on - * any of the open FDs, which basically amounts to checking for any incoming data. If you have a class of event loops - * that that do almost *no* inbound activity (like cassandra's outbound channels), then it behooves us to have the - * outbound netty channel consume as many tasks as it can before making the system calls to check up on the FDs, - * as we're not expecting any incoming data on those sockets, anyways. Thus, we pass the magic value {@code 100} - * to achieve the maximum consuption from the netty queue. (for implementation details, as of netty 4.1.8, - * see {@link io.netty.channel.epoll.EpollEventLoop#run()}. - */ - static EventLoopGroup getEventLoopGroup(boolean useEpoll, int threadCount, String threadNamePrefix, boolean boostIoRatio) - { - if (useEpoll) - { - logger.debug("using netty epoll event loop for pool prefix {}", threadNamePrefix); - EpollEventLoopGroup eventLoopGroup = new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(threadNamePrefix, true)); - if (boostIoRatio) - eventLoopGroup.setIoRatio(100); - return eventLoopGroup; - } - - logger.debug("using netty nio event loop for pool prefix {}", threadNamePrefix); - NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup(threadCount, new DefaultThreadFactory(threadNamePrefix, true)); - if (boostIoRatio) - eventLoopGroup.setIoRatio(100); - return eventLoopGroup; - } - - /** - * Create a {@link Channel} that listens on the {@code localAddr}. This method will block while trying to bind to the address, - * but it does not make a remote call. - */ - public Channel createInboundChannel(InetAddressAndPort localAddr, InboundInitializer initializer, int receiveBufferSize) throws ConfigurationException - { - String nic = FBUtilities.getNetworkInterface(localAddr.address); - logger.info("Starting Messaging Service on {} {}, encryption: {}", - localAddr, nic == null ? 
"" : String.format(" (%s)", nic), encryptionLogStatement(initializer.encryptionOptions)); - Class transport = useEpoll ? EpollServerSocketChannel.class : NioServerSocketChannel.class; - ServerBootstrap bootstrap = new ServerBootstrap().group(acceptGroup, inboundGroup) - .channel(transport) - .option(ChannelOption.SO_BACKLOG, 500) - .childOption(ChannelOption.SO_KEEPALIVE, true) - .childOption(ChannelOption.TCP_NODELAY, true) - .childOption(ChannelOption.SO_REUSEADDR, true) - .childOption(ChannelOption.SO_SNDBUF, INBOUND_CHANNEL_SEND_BUFFER_SIZE) - .childHandler(initializer); - - if (useEpoll) - bootstrap.childOption(EpollChannelOption.TCP_USER_TIMEOUT, DatabaseDescriptor.getInternodeTcpUserTimeoutInMS()); - - if (receiveBufferSize > 0) - bootstrap.childOption(ChannelOption.SO_RCVBUF, receiveBufferSize); - - ChannelFuture channelFuture = bootstrap.bind(new InetSocketAddress(localAddr.address, localAddr.port)); - - if (!channelFuture.awaitUninterruptibly().isSuccess()) - { - if (channelFuture.channel().isOpen()) - channelFuture.channel().close(); - - Throwable failedChannelCause = channelFuture.cause(); - - String causeString = ""; - if (failedChannelCause != null && failedChannelCause.getMessage() != null) - causeString = failedChannelCause.getMessage(); - - if (causeString.contains("in use")) - { - throw new ConfigurationException(localAddr + " is in use by another process. Change listen_address:storage_port " + - "in cassandra.yaml to values that do not conflict with other services"); - } - // looking at the jdk source, solaris/windows bind failue messages both use the phrase "cannot assign requested address". - // windows message uses "Cannot" (with a capital 'C'), and solaris (a/k/a *nux) doe not. hence we search for "annot" - else if (causeString.contains("annot assign requested address")) - { - throw new ConfigurationException("Unable to bind to address " + localAddr - + ". 
Set listen_address in cassandra.yaml to an interface you can bind to, e.g., your private IP address on EC2"); - } - else - { - throw new ConfigurationException("failed to bind to: " + localAddr, failedChannelCause); - } - } - - return channelFuture.channel(); - } - - /** - * Creates a new {@link SslHandler} from provided SslContext. - * @param peer enables endpoint verification for remote address when not null - */ - static SslHandler newSslHandler(Channel channel, SslContext sslContext, @Nullable InetSocketAddress peer) - { - if (peer == null) - { - return sslContext.newHandler(channel.alloc()); - } - else - { - logger.debug("Creating SSL handler for {}:{}", peer.getHostString(), peer.getPort()); - SslHandler sslHandler = sslContext.newHandler(channel.alloc(), peer.getHostString(), peer.getPort()); - SSLEngine engine = sslHandler.engine(); - SSLParameters sslParameters = engine.getSSLParameters(); - sslParameters.setEndpointIdentificationAlgorithm("HTTPS"); - engine.setSSLParameters(sslParameters); - return sslHandler; - } - } - - public static class InboundInitializer extends ChannelInitializer - { - private final IInternodeAuthenticator authenticator; - private final ServerEncryptionOptions encryptionOptions; - private final ChannelGroup channelGroup; - - public InboundInitializer(IInternodeAuthenticator authenticator, ServerEncryptionOptions encryptionOptions, ChannelGroup channelGroup) - { - this.authenticator = authenticator; - this.encryptionOptions = encryptionOptions; - this.channelGroup = channelGroup; - } - - @Override - public void initChannel(SocketChannel channel) throws Exception - { - channelGroup.add(channel); - ChannelPipeline pipeline = channel.pipeline(); - - // order of handlers: ssl -> logger -> handshakeHandler - if (encryptionOptions.enabled) - { - if (encryptionOptions.optional) - { - pipeline.addFirst(OPTIONAL_SSL_CHANNEL_HANDLER_NAME, new OptionalSslHandler(encryptionOptions)); - } - else - { - SslContext sslContext = 
SSLFactory.getOrCreateSslContext(encryptionOptions, true, SSLFactory.SocketType.SERVER); - InetSocketAddress peer = encryptionOptions.require_endpoint_verification ? channel.remoteAddress() : null; - SslHandler sslHandler = newSslHandler(channel, sslContext, peer); - logger.trace("creating inbound netty SslContext: context={}, engine={}", sslContext.getClass().getName(), sslHandler.engine().getClass().getName()); - pipeline.addFirst(SSL_CHANNEL_HANDLER_NAME, sslHandler); - } - } - - if (WIRETRACE) - pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO)); - - channel.pipeline().addLast(HANDSHAKE_HANDLER_NAME, new InboundHandshakeHandler(authenticator)); - } - } - - private static String encryptionLogStatement(ServerEncryptionOptions options) - { - if (options == null) - return "disabled"; - - String encryptionType = OpenSsl.isAvailable() ? "openssl" : "jdk"; - return "enabled (" + encryptionType + ')'; - } - - /** - * Create the {@link Bootstrap} for connecting to a remote peer. This method does not attempt to connect to the peer, - * and thus does not block. - */ - @VisibleForTesting - public Bootstrap createOutboundBootstrap(OutboundConnectionParams params) - { - logger.debug("creating outbound bootstrap to peer {}, compression: {}, encryption: {}, coalesce: {}, protocolVersion: {}", - params.connectionId.connectionAddress(), - params.compress, encryptionLogStatement(params.encryptionOptions), - params.coalescingStrategy.isPresent() ? params.coalescingStrategy.get() : CoalescingStrategies.Strategy.DISABLED, - params.protocolVersion); - Class transport = useEpoll ? EpollSocketChannel.class : NioSocketChannel.class; - Bootstrap bootstrap = new Bootstrap().group(params.mode == Mode.MESSAGING ? 
outboundGroup : streamingGroup) - .channel(transport) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, params.tcpConnectTimeoutInMS) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.SO_REUSEADDR, true) - .option(ChannelOption.SO_SNDBUF, params.sendBufferSize) - .option(ChannelOption.SO_RCVBUF, OUTBOUND_CHANNEL_RECEIVE_BUFFER_SIZE) - .option(ChannelOption.TCP_NODELAY, params.tcpNoDelay) - .option(ChannelOption.WRITE_BUFFER_WATER_MARK, params.waterMark) - .handler(new OutboundInitializer(params)); - if (useEpoll) - bootstrap.option(EpollChannelOption.TCP_USER_TIMEOUT, params.tcpUserTimeoutInMS); - - InetAddressAndPort remoteAddress = params.connectionId.connectionAddress(); - bootstrap.remoteAddress(new InetSocketAddress(remoteAddress.address, remoteAddress.port)); - return bootstrap; - } - - public static class OutboundInitializer extends ChannelInitializer - { - private final OutboundConnectionParams params; - - OutboundInitializer(OutboundConnectionParams params) - { - this.params = params; - } - - /** - * {@inheritDoc} - * - * To determine if we should enable TLS, we only need to check if {@link #params#encryptionOptions} is set. - * The logic for figuring that out is is located in {@link MessagingService#getMessagingConnection(InetAddress)}; - */ - public void initChannel(SocketChannel channel) throws Exception - { - ChannelPipeline pipeline = channel.pipeline(); - - // order of handlers: ssl -> logger -> handshakeHandler - if (params.encryptionOptions != null) - { - SslContext sslContext = SSLFactory.getOrCreateSslContext(params.encryptionOptions, true, SSLFactory.SocketType.CLIENT); - // for some reason channel.remoteAddress() will return null - InetAddressAndPort address = params.connectionId.remote(); - InetSocketAddress peer = params.encryptionOptions.require_endpoint_verification ? 
new InetSocketAddress(address.address, address.port) : null; - SslHandler sslHandler = newSslHandler(channel, sslContext, peer); - logger.trace("creating outbound netty SslContext: context={}, engine={}", sslContext.getClass().getName(), sslHandler.engine().getClass().getName()); - pipeline.addFirst(SSL_CHANNEL_HANDLER_NAME, sslHandler); - } - - if (NettyFactory.WIRETRACE) - pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO)); - - pipeline.addLast(HANDSHAKE_HANDLER_NAME, new OutboundHandshakeHandler(params)); - } - } - - public void close() throws InterruptedException - { - EventLoopGroup[] groups = new EventLoopGroup[] { acceptGroup, outboundGroup, inboundGroup, streamingGroup }; - for (EventLoopGroup group : groups) - group.shutdownGracefully(0, 2, TimeUnit.SECONDS); - for (EventLoopGroup group : groups) - group.awaitTermination(60, TimeUnit.SECONDS); - } - - static Lz4FrameEncoder createLz4Encoder(int protocolVersion) - { - return new Lz4FrameEncoder(LZ4Factory.fastestInstance(), false, COMPRESSION_BLOCK_SIZE, checksumForFrameEncoders(protocolVersion)); - } - - private static Checksum checksumForFrameEncoders(int protocolVersion) - { - if (protocolVersion >= MessagingService.current_version) - return ChecksumType.CRC32.newInstance(); - return XXHashFactory.fastestInstance().newStreamingHash32(LZ4_HASH_SEED).asChecksum(); - } - - static Lz4FrameDecoder createLz4Decoder(int protocolVersion) - { - return new Lz4FrameDecoder(LZ4Factory.fastestInstance(), checksumForFrameEncoders(protocolVersion)); - } - - public static EventExecutor executorForChannelGroups() - { - return new DefaultEventExecutor(); - } -} diff --git a/src/java/org/apache/cassandra/net/async/NonClosingDefaultFileRegion.java b/src/java/org/apache/cassandra/net/async/NonClosingDefaultFileRegion.java deleted file mode 100644 index 46f0ce162a34..000000000000 --- a/src/java/org/apache/cassandra/net/async/NonClosingDefaultFileRegion.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.File; -import java.nio.channels.FileChannel; - -import io.netty.channel.DefaultFileRegion; - -/** - * Netty's DefaultFileRegion closes the underlying FileChannel as soon as - * the refCnt() for the region drops to zero, this is an implementation of - * the DefaultFileRegion that doesn't close the FileChannel. - * - * See {@link ByteBufDataOutputStreamPlus} for its usage. 
- */ -public class NonClosingDefaultFileRegion extends DefaultFileRegion -{ - - public NonClosingDefaultFileRegion(FileChannel file, long position, long count) - { - super(file, position, count); - } - - public NonClosingDefaultFileRegion(File f, long position, long count) - { - super(f, position, count); - } - - @Override - protected void deallocate() - { - // Overridden to avoid closing the file - } -} diff --git a/src/java/org/apache/cassandra/net/async/OptionalSslHandler.java b/src/java/org/apache/cassandra/net/async/OptionalSslHandler.java deleted file mode 100644 index 3fb856244c51..000000000000 --- a/src/java/org/apache/cassandra/net/async/OptionalSslHandler.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.net.InetSocketAddress; -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslHandler; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.security.SSLFactory; - -public class OptionalSslHandler extends ByteToMessageDecoder -{ - private final ServerEncryptionOptions encryptionOptions; - - OptionalSslHandler(ServerEncryptionOptions encryptionOptions) - { - this.encryptionOptions = encryptionOptions; - } - - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception - { - if (in.readableBytes() < 5) - { - // To detect if SSL must be used we need to have at least 5 bytes, so return here and try again - // once more bytes a ready. - return; - } - - if (SslHandler.isEncrypted(in)) - { - // Connection uses SSL/TLS, replace the detection handler with a SslHandler and so use encryption. - SslContext sslContext = SSLFactory.getOrCreateSslContext(encryptionOptions, true, SSLFactory.SocketType.SERVER); - Channel channel = ctx.channel(); - InetSocketAddress peer = encryptionOptions.require_endpoint_verification ? (InetSocketAddress) channel.remoteAddress() : null; - SslHandler sslHandler = NettyFactory.newSslHandler(channel, sslContext, peer); - ctx.pipeline().replace(this, NettyFactory.SSL_CHANNEL_HANDLER_NAME, sslHandler); - } - else - { - // Connection use no TLS/SSL encryption, just remove the detection handler and continue without - // SslHandler in the pipeline. 
- ctx.pipeline().remove(this); - } - } -} diff --git a/src/java/org/apache/cassandra/net/async/OutboundConnectionIdentifier.java b/src/java/org/apache/cassandra/net/async/OutboundConnectionIdentifier.java deleted file mode 100644 index e3090657c947..000000000000 --- a/src/java/org/apache/cassandra/net/async/OutboundConnectionIdentifier.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import com.carrotsearch.hppc.IntObjectMap; -import com.carrotsearch.hppc.IntObjectOpenHashMap; -import org.apache.cassandra.locator.InetAddressAndPort; - -/** - * Identifies an outbound messaging connection. - * - * This mainly hold the remote address and the type (small/large messages or gossip) of connection used, but with the - * additional detail that in some case (typically public EC2 address across regions) the address to which we connect - * to the remote is different from the address by which the node is known by the rest of the C*. 
- */ -public class OutboundConnectionIdentifier -{ - public enum ConnectionType - { - GOSSIP (0), LARGE_MESSAGE (1), SMALL_MESSAGE (2), STREAM (3); - - private final int id; - - ConnectionType(int id) - { - this.id = id; - } - - public int getId() - { - return id; - } - - private static final IntObjectMap idMap = new IntObjectOpenHashMap<>(values().length); - static - { - for (ConnectionType type : values()) - idMap.put(type.id, type); - } - - public static ConnectionType fromId(int id) - { - return idMap.get(id); - } - - } - - /** - * Memoization of the local node's broadcast address. - */ - private final InetAddressAndPort localAddr; - - /** - * The address by which the remote is identified. This may be different from {@link #remoteConnectionAddr} for - * something like EC2 public IP address which need to be used for communication between EC2 regions. - */ - private final InetAddressAndPort remoteAddr; - - /** - * The address to which we're connecting to the node (often the same as {@link #remoteAddr} but not always). - */ - private final InetAddressAndPort remoteConnectionAddr; - - private final ConnectionType connectionType; - - private OutboundConnectionIdentifier(InetAddressAndPort localAddr, - InetAddressAndPort remoteAddr, - InetAddressAndPort remoteConnectionAddr, - ConnectionType connectionType) - { - this.localAddr = localAddr; - this.remoteAddr = remoteAddr; - this.remoteConnectionAddr = remoteConnectionAddr; - this.connectionType = connectionType; - } - - private OutboundConnectionIdentifier(InetAddressAndPort localAddr, - InetAddressAndPort remoteAddr, - ConnectionType connectionType) - { - this(localAddr, remoteAddr, remoteAddr, connectionType); - } - - /** - * Creates an identifier for a small message connection and using the remote "identifying" address as its connection - * address. 
- */ - public static OutboundConnectionIdentifier small(InetAddressAndPort localAddr, InetAddressAndPort remoteAddr) - { - return new OutboundConnectionIdentifier(localAddr, remoteAddr, ConnectionType.SMALL_MESSAGE); - } - - /** - * Creates an identifier for a large message connection and using the remote "identifying" address as its connection - * address. - */ - public static OutboundConnectionIdentifier large(InetAddressAndPort localAddr, InetAddressAndPort remoteAddr) - { - return new OutboundConnectionIdentifier(localAddr, remoteAddr, ConnectionType.LARGE_MESSAGE); - } - - /** - * Creates an identifier for a gossip connection and using the remote "identifying" address as its connection - * address. - */ - public static OutboundConnectionIdentifier gossip(InetAddressAndPort localAddr, InetAddressAndPort remoteAddr) - { - return new OutboundConnectionIdentifier(localAddr, remoteAddr, ConnectionType.GOSSIP); - } - - /** - * Creates an identifier for a gossip connection and using the remote "identifying" address as its connection - * address. - */ - public static OutboundConnectionIdentifier stream(InetAddressAndPort localAddr, InetAddressAndPort remoteAddr) - { - return new OutboundConnectionIdentifier(localAddr, remoteAddr, ConnectionType.STREAM); - } - - /** - * Returns a newly created connection identifier to the same remote that this identifier, but using the provided - * address as connection address. - * - * @param remoteConnectionAddr the address to use for connection to the remote in the new identifier. - * @return a newly created connection identifier that differs from this one only by using {@code remoteConnectionAddr} - * as connection address to the remote. 
- */ - public OutboundConnectionIdentifier withNewConnectionAddress(InetAddressAndPort remoteConnectionAddr) - { - return new OutboundConnectionIdentifier(localAddr, remoteAddr, remoteConnectionAddr, connectionType); - } - - public OutboundConnectionIdentifier withNewConnectionPort(int port) - { - return new OutboundConnectionIdentifier(localAddr, InetAddressAndPort.getByAddressOverrideDefaults(remoteAddr.address, port), - InetAddressAndPort.getByAddressOverrideDefaults(remoteConnectionAddr.address, port), connectionType); - } - - /** - * The local node address. - */ - public InetAddressAndPort local() - { - return localAddr; - } - - /** - * The remote node identifying address (the one to use for anything else than connecting to the node). - */ - public InetAddressAndPort remote() - { - return remoteAddr; - } - - /** - * The remote node connection address (the one to use to actually connect to the remote, and only that). - */ - public InetAddressAndPort connectionAddress() - { - return remoteConnectionAddr; - } - - /** - * The type of this connection. - */ - ConnectionType type() - { - return connectionType; - } - - @Override - public String toString() - { - return remoteAddr.equals(remoteConnectionAddr) - ? String.format("%s (%s)", remoteAddr, connectionType) - : String.format("%s on %s (%s)", remoteAddr, remoteConnectionAddr, connectionType); - } -} diff --git a/src/java/org/apache/cassandra/net/async/OutboundConnectionParams.java b/src/java/org/apache/cassandra/net/async/OutboundConnectionParams.java deleted file mode 100644 index 64968c6cf403..000000000000 --- a/src/java/org/apache/cassandra/net/async/OutboundConnectionParams.java +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import com.google.common.base.Preconditions; - -import io.netty.channel.WriteBufferWaterMark; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; - -/** - * A collection of data points to be passed around for outbound connections. 
- */ -public class OutboundConnectionParams -{ - public static final int DEFAULT_SEND_BUFFER_SIZE = 1 << 16; - - final OutboundConnectionIdentifier connectionId; - final Consumer callback; - final ServerEncryptionOptions encryptionOptions; - final NettyFactory.Mode mode; - final boolean compress; - final Optional coalescingStrategy; - final int sendBufferSize; - final boolean tcpNoDelay; - final Supplier backlogSupplier; - final Consumer messageResultConsumer; - final WriteBufferWaterMark waterMark; - final int protocolVersion; - final int tcpConnectTimeoutInMS; - final int tcpUserTimeoutInMS; - - private OutboundConnectionParams(OutboundConnectionIdentifier connectionId, - Consumer callback, - ServerEncryptionOptions encryptionOptions, - NettyFactory.Mode mode, - boolean compress, - Optional coalescingStrategy, - int sendBufferSize, - boolean tcpNoDelay, - Supplier backlogSupplier, - Consumer messageResultConsumer, - WriteBufferWaterMark waterMark, - int protocolVersion, - int tcpConnectTimeoutInMS, - int tcpUserTimeoutInMS) - { - this.connectionId = connectionId; - this.callback = callback; - this.encryptionOptions = encryptionOptions; - this.mode = mode; - this.compress = compress; - this.coalescingStrategy = coalescingStrategy; - this.sendBufferSize = sendBufferSize; - this.tcpNoDelay = tcpNoDelay; - this.backlogSupplier = backlogSupplier; - this.messageResultConsumer = messageResultConsumer; - this.waterMark = waterMark; - this.protocolVersion = protocolVersion; - this.tcpConnectTimeoutInMS = tcpConnectTimeoutInMS; - this.tcpUserTimeoutInMS = tcpUserTimeoutInMS; - } - - public static Builder builder() - { - return new Builder(); - } - - public static Builder builder(OutboundConnectionParams params) - { - return new Builder(params); - } - - public static class Builder - { - private OutboundConnectionIdentifier connectionId; - private Consumer callback; - private ServerEncryptionOptions encryptionOptions; - private NettyFactory.Mode mode; - private boolean 
compress; - private Optional coalescingStrategy = Optional.empty(); - private int sendBufferSize = DEFAULT_SEND_BUFFER_SIZE; - private boolean tcpNoDelay; - private Supplier backlogSupplier; - private Consumer messageResultConsumer; - private WriteBufferWaterMark waterMark = WriteBufferWaterMark.DEFAULT; - private int protocolVersion; - private int tcpConnectTimeoutInMS; - private int tcpUserTimeoutInMS; - - private Builder() - { - this.tcpConnectTimeoutInMS = DatabaseDescriptor.getInternodeTcpConnectTimeoutInMS(); - this.tcpUserTimeoutInMS = DatabaseDescriptor.getInternodeTcpUserTimeoutInMS(); - } - - private Builder(OutboundConnectionParams params) - { - this.connectionId = params.connectionId; - this.callback = params.callback; - this.encryptionOptions = params.encryptionOptions; - this.mode = params.mode; - this.compress = params.compress; - this.coalescingStrategy = params.coalescingStrategy; - this.sendBufferSize = params.sendBufferSize; - this.tcpNoDelay = params.tcpNoDelay; - this.backlogSupplier = params.backlogSupplier; - this.messageResultConsumer = params.messageResultConsumer; - this.tcpConnectTimeoutInMS = params.tcpConnectTimeoutInMS; - this.tcpUserTimeoutInMS = params.tcpUserTimeoutInMS; - } - - public Builder connectionId(OutboundConnectionIdentifier connectionId) - { - this.connectionId = connectionId; - return this; - } - - public Builder callback(Consumer callback) - { - this.callback = callback; - return this; - } - - public Builder encryptionOptions(ServerEncryptionOptions encryptionOptions) - { - this.encryptionOptions = encryptionOptions; - return this; - } - - public Builder mode(NettyFactory.Mode mode) - { - this.mode = mode; - return this; - } - - public Builder compress(boolean compress) - { - this.compress = compress; - return this; - } - - public Builder coalescingStrategy(Optional coalescingStrategy) - { - this.coalescingStrategy = coalescingStrategy; - return this; - } - - public Builder sendBufferSize(int sendBufferSize) - { - 
this.sendBufferSize = sendBufferSize; - return this; - } - - public Builder tcpNoDelay(boolean tcpNoDelay) - { - this.tcpNoDelay = tcpNoDelay; - return this; - } - - public Builder backlogSupplier(Supplier backlogSupplier) - { - this.backlogSupplier = backlogSupplier; - return this; - } - - public Builder messageResultConsumer(Consumer messageResultConsumer) - { - this.messageResultConsumer = messageResultConsumer; - return this; - } - - public Builder waterMark(WriteBufferWaterMark waterMark) - { - this.waterMark = waterMark; - return this; - } - - public Builder protocolVersion(int protocolVersion) - { - this.protocolVersion = protocolVersion; - return this; - } - - public Builder tcpConnectTimeoutInMS(int tcpConnectTimeoutInMS) - { - this.tcpConnectTimeoutInMS = tcpConnectTimeoutInMS; - return this; - } - - public Builder tcpUserTimeoutInMS(int tcpUserTimeoutInMS) - { - this.tcpUserTimeoutInMS = tcpUserTimeoutInMS; - return this; - } - - public OutboundConnectionParams build() - { - Preconditions.checkArgument(protocolVersion > 0, "illegal protocol version: " + protocolVersion); - Preconditions.checkArgument(sendBufferSize > 0 && sendBufferSize < 1 << 20, "illegal send buffer size: " + sendBufferSize); - Preconditions.checkArgument(tcpUserTimeoutInMS >= 0, "tcp user timeout must be non negative: " + tcpUserTimeoutInMS); - Preconditions.checkArgument(tcpConnectTimeoutInMS > 0, "tcp connect timeout must be positive: " + tcpConnectTimeoutInMS); - - return new OutboundConnectionParams(connectionId, callback, encryptionOptions, mode, compress, coalescingStrategy, sendBufferSize, - tcpNoDelay, backlogSupplier, messageResultConsumer, waterMark, protocolVersion, tcpConnectTimeoutInMS, tcpUserTimeoutInMS); - } - } -} diff --git a/src/java/org/apache/cassandra/net/async/OutboundHandshakeHandler.java b/src/java/org/apache/cassandra/net/async/OutboundHandshakeHandler.java deleted file mode 100644 index 3ccbf49ccf3c..000000000000 --- 
a/src/java/org/apache/cassandra/net/async/OutboundHandshakeHandler.java +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; -import io.netty.channel.ChannelPipeline; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.timeout.IdleStateHandler; -import io.netty.util.concurrent.Future; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; - -import static 
org.apache.cassandra.config.Config.PROPERTY_PREFIX; - -/** - * A {@link ChannelHandler} to execute the send-side of the internode communication handshake protocol. - * As soon as the handler is added to the channel via {@link #channelActive(ChannelHandlerContext)} - * (which is only invoked if the underlying TCP connection was properly established), the {@link FirstHandshakeMessage} - * of the internode messaging protocol is automatically sent out. See {@link HandshakeProtocol} for full details - * about the internode messaging hndshake protocol. - *

- * Upon completion of the handshake (on success or fail), the {@link #callback} is invoked to let the listener - * know the result of the handshake. See {@link HandshakeResult} for details about the different result states. - *

- * This class extends {@link ByteToMessageDecoder}, which is a {@link ChannelInboundHandler}, because this handler - * waits for the peer's handshake response (the {@link SecondHandshakeMessage} of the internode messaging handshake protocol). - */ -public class OutboundHandshakeHandler extends ByteToMessageDecoder -{ - private static final Logger logger = LoggerFactory.getLogger(OutboundHandshakeHandler.class); - - /** - * The number of milliseconds to wait before closing a channel if there has been no progress (when there is - * data to be sent).See {@link IdleStateHandler} and {@link MessageOutHandler#userEventTriggered(ChannelHandlerContext, Object)}. - */ - private static final long DEFAULT_WRITE_IDLE_MS = TimeUnit.SECONDS.toMillis(10); - private static final String WRITE_IDLE_PROPERTY = PROPERTY_PREFIX + "outbound_write_idle_ms"; - private static final long WRITE_IDLE_MS = Long.getLong(WRITE_IDLE_PROPERTY, DEFAULT_WRITE_IDLE_MS); - - private final OutboundConnectionIdentifier connectionId; - - /** - * The expected messaging service version to use. - */ - private final int messagingVersion; - - /** - * A function to invoke upon completion of the attempt, success or failure, to connect to the peer. - */ - private final Consumer callback; - private final NettyFactory.Mode mode; - private final OutboundConnectionParams params; - - OutboundHandshakeHandler(OutboundConnectionParams params) - { - this.params = params; - this.connectionId = params.connectionId; - this.messagingVersion = params.protocolVersion; - this.callback = params.callback; - this.mode = params.mode; - } - - /** - * {@inheritDoc} - * - * Invoked when the channel is made active, and sends out the {@link FirstHandshakeMessage}. - * In the case of streaming, we do not require a full bi-directional handshake; the initial message, - * containing the streaming protocol version, is all that is required. 
- */ - @Override - public void channelActive(final ChannelHandlerContext ctx) throws Exception - { - FirstHandshakeMessage msg = new FirstHandshakeMessage(messagingVersion, mode, params.compress); - logger.trace("starting handshake with peer {}, msg = {}", connectionId.connectionAddress(), msg); - ctx.writeAndFlush(msg.encode(ctx.alloc())).addListener(future -> firstHandshakeMessageListener(future, ctx)); - - if (mode == NettyFactory.Mode.STREAMING) - ctx.pipeline().remove(this); - - ctx.fireChannelActive(); - } - - /** - * A simple listener to make sure we could send the {@link FirstHandshakeMessage} to the socket, - * and fail the handshake attempt if we could not (for example, maybe we could create the TCP socket, but then - * the connection gets closed for some reason). - */ - void firstHandshakeMessageListener(Future future, ChannelHandlerContext ctx) - { - if (future.isSuccess()) - return; - - ChannelFuture channelFuture = (ChannelFuture)future; - exceptionCaught(ctx, channelFuture.cause()); - } - - /** - * {@inheritDoc} - * - * Invoked when we get the response back from the peer, which should contain the second message of the internode messaging handshake. - *

- * If the peer's protocol version does not equal what we were expecting, immediately close the channel (and socket); - * do *not* send out the third message of the internode messaging handshake. - * We will reconnect on the appropriate protocol version. - */ - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception - { - SecondHandshakeMessage msg = SecondHandshakeMessage.maybeDecode(in); - if (msg == null) - return; - - logger.trace("received second handshake message from peer {}, msg = {}", connectionId.connectionAddress(), msg); - final int peerMessagingVersion = msg.messagingVersion; - - // we expected a higher protocol version, but it was actually lower - if (messagingVersion > peerMessagingVersion) - { - logger.trace("peer's max version is {}; will reconnect with that version", peerMessagingVersion); - try - { - if (DatabaseDescriptor.getSeeds().contains(connectionId.remote())) - logger.warn("Seed gossip version is {}; will not connect with that version", peerMessagingVersion); - } - catch (Throwable e) - { - // If invalid yaml has been added to the config since startup, getSeeds() will throw an AssertionError - // Additionally, third party seed providers may throw exceptions if network is flakey. 
- // Regardless of what's thrown, we must catch it, disconnect, and try again - logger.warn("failed to reread yaml (on trying to connect to a seed): {}", e.getLocalizedMessage()); - } - ctx.close(); - callback.accept(HandshakeResult.disconnect(peerMessagingVersion)); - return; - } - // we anticipate a version that is lower than what peer is actually running - else if (messagingVersion < peerMessagingVersion && messagingVersion < MessagingService.current_version) - { - logger.trace("peer has a higher max version than expected {} (previous value {})", peerMessagingVersion, messagingVersion); - ctx.close(); - callback.accept(HandshakeResult.disconnect(peerMessagingVersion)); - return; - } - - try - { - ctx.writeAndFlush(new ThirdHandshakeMessage(peerMessagingVersion, connectionId.local()).encode(ctx.alloc())); - ChannelWriter channelWriter = setupPipeline(ctx.channel(), peerMessagingVersion); - callback.accept(HandshakeResult.success(channelWriter, peerMessagingVersion)); - } - catch (Exception e) - { - logger.info("failed to finalize internode messaging handshake", e); - ctx.close(); - callback.accept(HandshakeResult.failed()); - } - } - - @VisibleForTesting - ChannelWriter setupPipeline(Channel channel, int messagingVersion) - { - ChannelPipeline pipeline = channel.pipeline(); - pipeline.addLast("idleWriteHandler", new IdleStateHandler(true, 0, WRITE_IDLE_MS, 0, TimeUnit.MILLISECONDS)); - if (params.compress) - pipeline.addLast(NettyFactory.OUTBOUND_COMPRESSOR_HANDLER_NAME, NettyFactory.createLz4Encoder(messagingVersion)); - - ChannelWriter channelWriter = ChannelWriter.create(channel, params.messageResultConsumer, params.coalescingStrategy); - pipeline.addLast("messageOutHandler", new MessageOutHandler(connectionId, messagingVersion, channelWriter, params.backlogSupplier)); - pipeline.remove(this); - return channelWriter; - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - logger.error("Failed to properly handshake with 
peer {}. Closing the channel.", connectionId, cause); - ctx.close(); - callback.accept(HandshakeResult.failed()); - } - - /** - * The result of the handshake. Handshake has 3 possible outcomes: - * 1) it can be successful, in which case the channel and version to used is returned in this result. - * 2) we may decide to disconnect to reconnect with another protocol version (namely, the version is passed in this result). - * 3) we can have a negotiation failure for an unknown reason. (#sadtrombone) - */ - public static class HandshakeResult - { - static final int UNKNOWN_PROTOCOL_VERSION = -1; - - /** - * Describes the result of receiving the response back from the peer (Message 2 of the handshake) - * and implies an action that should be taken. - */ - enum Outcome - { - SUCCESS, DISCONNECT, NEGOTIATION_FAILURE - } - - /** The channel for the connection, only set for successful handshake. */ - final ChannelWriter channelWriter; - /** The version negotiated with the peer. Set unless this is a {@link Outcome#NEGOTIATION_FAILURE}. */ - final int negotiatedMessagingVersion; - /** The handshake {@link Outcome}. 
*/ - final Outcome outcome; - - private HandshakeResult(ChannelWriter channelWriter, int negotiatedMessagingVersion, Outcome outcome) - { - this.channelWriter = channelWriter; - this.negotiatedMessagingVersion = negotiatedMessagingVersion; - this.outcome = outcome; - } - - static HandshakeResult success(ChannelWriter channel, int negotiatedMessagingVersion) - { - return new HandshakeResult(channel, negotiatedMessagingVersion, Outcome.SUCCESS); - } - - static HandshakeResult disconnect(int negotiatedMessagingVersion) - { - return new HandshakeResult(null, negotiatedMessagingVersion, Outcome.DISCONNECT); - } - - static HandshakeResult failed() - { - return new HandshakeResult(null, UNKNOWN_PROTOCOL_VERSION, Outcome.NEGOTIATION_FAILURE); - } - } -} diff --git a/src/java/org/apache/cassandra/net/async/OutboundMessagingConnection.java b/src/java/org/apache/cassandra/net/async/OutboundMessagingConnection.java deleted file mode 100644 index 265ece9b26ee..000000000000 --- a/src/java/org/apache/cassandra/net/async/OutboundMessagingConnection.java +++ /dev/null @@ -1,747 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.util.concurrent.Future; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.concurrent.ScheduledExecutors; -import org.apache.cassandra.config.Config; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.NettyFactory.Mode; -import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult; -import org.apache.cassandra.utils.CoalescingStrategies; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; -import org.apache.cassandra.utils.JVMStabilityInspector; -import org.apache.cassandra.utils.NoSpamLogger; - -/** - * Represents one connection to a peer, and handles the state transistions on the connection and the netty {@link Channel} - * The underlying socket is not opened until explicitly requested (by sending a message). - * - * The basic setup for the channel is like this: a message is requested to be sent via {@link #sendMessage(MessageOut, int)}. - * If the channel is not established, then we need to create it (obviously). 
To prevent multiple threads from creating - * independent connections, they attempt to update the {@link #state}; one thread will win the race and create the connection. - * Upon sucessfully setting up the connection/channel, the {@link #state} will be updated again (to {@link State#READY}, - * which indicates to other threads that the channel is ready for business and can be written to. - * - */ -public class OutboundMessagingConnection -{ - static final Logger logger = LoggerFactory.getLogger(OutboundMessagingConnection.class); - private static final NoSpamLogger errorLogger = NoSpamLogger.getLogger(logger, 10, TimeUnit.SECONDS); - - private static final String INTRADC_TCP_NODELAY_PROPERTY = Config.PROPERTY_PREFIX + "otc_intradc_tcp_nodelay"; - - /** - * Enabled/disable TCP_NODELAY for intradc connections. Defaults to enabled. - */ - private static final boolean INTRADC_TCP_NODELAY = Boolean.parseBoolean(System.getProperty(INTRADC_TCP_NODELAY_PROPERTY, "true")); - - /** - * Number of milliseconds between connection createRetry attempts. - */ - private static final int OPEN_RETRY_DELAY_MS = 100; - - /** - * A minimum number of milliseconds to wait for a connection (TCP socket connect + handshake) - */ - private static final int MINIMUM_CONNECT_TIMEOUT_MS = 2000; - private final IInternodeAuthenticator authenticator; - - /** - * Describes this instance's ability to send messages into it's Netty {@link Channel}. - */ - enum State - { - /** waiting to create the connection */ - NOT_READY, - /** we've started to create the connection/channel */ - CREATING_CHANNEL, - /** channel is established and we can send messages */ - READY, - /** a dead state which should not be transitioned away from */ - CLOSED - } - - /** - * Backlog to hold messages passed by upstream threads while the Netty {@link Channel} is being set up or recreated. 
- */ - private final Queue backlog; - - /** - * Reference to a {@link ScheduledExecutorService} rther than directly depending on something like {@link ScheduledExecutors}. - */ - private final ScheduledExecutorService scheduledExecutor; - - final AtomicLong droppedMessageCount; - final AtomicLong completedMessageCount; - - private volatile OutboundConnectionIdentifier connectionId; - - private final ServerEncryptionOptions encryptionOptions; - - /** - * A future for retrying connections. Bear in mind that this future does not execute in the - * netty event event loop, so there's some races to be careful of. - */ - private volatile ScheduledFuture connectionRetryFuture; - - /** - * A future for notifying when the timeout for creating the connection and negotiating the handshake has elapsed. - * It will be cancelled when the channel is established correctly. Bear in mind that this future does not execute in the - * netty event event loop, so there's some races to be careful of. - */ - private volatile ScheduledFuture connectionTimeoutFuture; - - private final AtomicReference state; - - private final Optional coalescingStrategy; - - /** - * A running count of the number of times we've tried to create a connection. - */ - private volatile int connectAttemptCount; - - /** - * The netty channel, once a socket connection is established; it won't be in it's normal working state until the handshake is complete. 
- */ - private volatile ChannelWriter channelWriter; - - /** - * the target protocol version to communicate to the peer with, discovered/negotiated via handshaking - */ - private int targetVersion; - - OutboundMessagingConnection(OutboundConnectionIdentifier connectionId, - ServerEncryptionOptions encryptionOptions, - Optional coalescingStrategy, - IInternodeAuthenticator authenticator) - { - this(connectionId, encryptionOptions, coalescingStrategy, authenticator, ScheduledExecutors.scheduledFastTasks); - } - - @VisibleForTesting - OutboundMessagingConnection(OutboundConnectionIdentifier connectionId, - ServerEncryptionOptions encryptionOptions, - Optional coalescingStrategy, - IInternodeAuthenticator authenticator, - ScheduledExecutorService sceduledExecutor) - { - this.connectionId = connectionId; - this.encryptionOptions = encryptionOptions; - this.authenticator = authenticator; - backlog = new ConcurrentLinkedQueue<>(); - droppedMessageCount = new AtomicLong(0); - completedMessageCount = new AtomicLong(0); - state = new AtomicReference<>(State.NOT_READY); - this.scheduledExecutor = sceduledExecutor; - this.coalescingStrategy = coalescingStrategy; - - // We want to use the most precise protocol version we know because while there is version detection on connect(), - // the target version might be accessed by the pool (in getConnection()) before we actually connect (as we - // only connect when the first message is submitted). Note however that the only case where we'll connect - // without knowing the true version of a node is if that node is a seed (otherwise, we can't know a node - // unless it has been gossiped to us or it has connected to us, and in both cases that will set the version). - // In that case we won't rely on that targetVersion before we're actually connected and so the version - // detection in connect() will do its job. 
- targetVersion = MessagingService.instance().getVersion(connectionId.remote()); - } - - /** - * If the connection is set up and ready to use (the normal case), simply send the message to it and return. - * Otherwise, one lucky thread is selected to create the Channel, while other threads just add the {@code msg} to - * the backlog queue. - * - * @return true if the message was accepted by the {@link #channelWriter}; else false if it was not accepted - * and added to the backlog or the channel is {@link State#CLOSED}. See documentation in {@link ChannelWriter} and - * {@link MessageOutHandler} how the backlogged messages get consumed. - */ - boolean sendMessage(MessageOut msg, int id) - { - return sendMessage(new QueuedMessage(msg, id)); - } - - boolean sendMessage(QueuedMessage queuedMessage) - { - State state = this.state.get(); - if (state == State.READY) - { - if (channelWriter.write(queuedMessage, false)) - return true; - - backlog.add(queuedMessage); - return false; - } - else if (state == State.CLOSED) - { - errorLogger.warn("trying to write message to a closed connection"); - return false; - } - else - { - backlog.add(queuedMessage); - connect(); - return true; - } - } - - /** - * Initiate all the actions required to establish a working, valid connection. This includes - * opening the socket, negotiating the internode messaging handshake, and setting up the working - * Netty {@link Channel}. However, this method will not block for all those actions: it will only - * kick off the connection attempt as everything is asynchronous. - *

- * Threads compete to update the {@link #state} field to {@link State#CREATING_CHANNEL} to ensure only one - * connection is attempted at a time. - * - * @return true if kicking off the connection attempt was started by this thread; else, false. - */ - public boolean connect() - { - // try to be the winning thread to create the channel - if (!state.compareAndSet(State.NOT_READY, State.CREATING_CHANNEL)) - return false; - - // clean up any lingering connection attempts - if (connectionTimeoutFuture != null) - { - connectionTimeoutFuture.cancel(false); - connectionTimeoutFuture = null; - } - - return tryConnect(); - } - - private boolean tryConnect() - { - if (state.get() != State.CREATING_CHANNEL) - return false; - - logger.debug("connection attempt {} to {}", connectAttemptCount, connectionId); - - - InetAddressAndPort remote = connectionId.remote(); - if (!authenticator.authenticate(remote.address, remote.port)) - { - logger.warn("Internode auth failed connecting to {}", connectionId); - //Remove the connection pool and other thread so messages aren't queued - MessagingService.instance().destroyConnectionPool(remote); - - // don't update the state field as destroyConnectionPool() *should* call OMC.close() - // on all the connections in the OMP for the remoteAddress - return false; - } - - boolean compress = shouldCompressConnection(connectionId.local(), connectionId.remote()); - maybeUpdateConnectionId(); - Bootstrap bootstrap = buildBootstrap(compress); - - ChannelFuture connectFuture = bootstrap.connect(); - connectFuture.addListener(this::connectCallback); - - long timeout = Math.max(MINIMUM_CONNECT_TIMEOUT_MS, DatabaseDescriptor.getRpcTimeout()); - if (connectionTimeoutFuture == null || connectionTimeoutFuture.isDone()) - connectionTimeoutFuture = scheduledExecutor.schedule(() -> connectionTimeout(connectFuture), timeout, TimeUnit.MILLISECONDS); - return true; - } - - @VisibleForTesting - static boolean shouldCompressConnection(InetAddressAndPort localHost, 
InetAddressAndPort remoteHost) - { - return (DatabaseDescriptor.internodeCompression() == Config.InternodeCompression.all) - || ((DatabaseDescriptor.internodeCompression() == Config.InternodeCompression.dc) && !isLocalDC(localHost, remoteHost)); - } - - /** - * After a bounce we won't necessarily know the peer's version, so we assume the peer is at least 4.0 - * and thus using a single port for secure and non-secure communication. However, during a rolling upgrade from - * 3.0.x/3.x to 4.0, the not-yet upgraded peer is still listening on separate ports, but we don't know the peer's - * version until we can successfully connect. Fortunately, the peer can connect to this node, at which point - * we'll grab it's version. We then use that knowledge to use the {@link Config#ssl_storage_port} to connect on, - * and to do that we need to update some member fields in this instance. - * - * Note: can be removed at 5.0 - */ - void maybeUpdateConnectionId() - { - if (encryptionOptions != null) - { - int version = MessagingService.instance().getVersion(connectionId.remote()); - if (version < targetVersion) - { - targetVersion = version; - int port = MessagingService.instance().portFor(connectionId.remote()); - connectionId = connectionId.withNewConnectionPort(port); - logger.debug("changing connectionId to {}, with a different port for secure communication, because peer version is {}", connectionId, version); - } - } - } - - private Bootstrap buildBootstrap(boolean compress) - { - boolean tcpNoDelay = isLocalDC(connectionId.local(), connectionId.remote()) ? INTRADC_TCP_NODELAY : DatabaseDescriptor.getInterDCTcpNoDelay(); - int sendBufferSize = DatabaseDescriptor.getInternodeSendBufferSize() > 0 - ? 
DatabaseDescriptor.getInternodeSendBufferSize() - : OutboundConnectionParams.DEFAULT_SEND_BUFFER_SIZE; - - int tcpConnectTimeout = DatabaseDescriptor.getInternodeTcpConnectTimeoutInMS(); - int tcpUserTimeout = DatabaseDescriptor.getInternodeTcpUserTimeoutInMS(); - - OutboundConnectionParams params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(this::finishHandshake) - .encryptionOptions(encryptionOptions) - .mode(Mode.MESSAGING) - .compress(compress) - .coalescingStrategy(coalescingStrategy) - .sendBufferSize(sendBufferSize) - .tcpNoDelay(tcpNoDelay) - .tcpConnectTimeoutInMS(tcpConnectTimeout) - .tcpUserTimeoutInMS(tcpUserTimeout) - .backlogSupplier(() -> nextBackloggedMessage()) - .messageResultConsumer(this::handleMessageResult) - .protocolVersion(targetVersion) - .build(); - - return NettyFactory.instance.createOutboundBootstrap(params); - } - - private QueuedMessage nextBackloggedMessage() - { - QueuedMessage msg = backlog.poll(); - if (msg == null) - return null; - - if (!msg.isTimedOut()) - return msg; - - if (msg.shouldRetry()) - return msg.createRetry(); - - droppedMessageCount.incrementAndGet(); - return null; - } - - static boolean isLocalDC(InetAddressAndPort localHost, InetAddressAndPort remoteHost) - { - String remoteDC = DatabaseDescriptor.getEndpointSnitch().getDatacenter(remoteHost); - String localDC = DatabaseDescriptor.getEndpointSnitch().getDatacenter(localHost); - return remoteDC != null && remoteDC.equals(localDC); - } - - /** - * Handles the callback of the TCP connection attempt (not including the handshake negotiation!), and really all - * we're handling here is the TCP connection failures. On failure, we close the channel (which should disconnect - * the socket, if connected). If there was an {@link IOException} while trying to connect, the connection will be - * retried after a short delay. - *

- * This method does not alter the {@link #state} as it's only evaluating the TCP connect, not TCP connect and handshake. - * Thus, {@link #finishHandshake(HandshakeResult)} will handle any necessary state updates. - *

- * Note: this method is called from the event loop, so be careful wrt thread visibility - * - * @return true iff the TCP connection was established and the {@link #state} is not {@link State#CLOSED}; else false. - */ - @VisibleForTesting - boolean connectCallback(Future future) - { - ChannelFuture channelFuture = (ChannelFuture)future; - - // make sure this instance is not (terminally) closed - if (state.get() == State.CLOSED) - { - channelFuture.channel().close(); - return false; - } - - // this is the success state - final Throwable cause = future.cause(); - if (cause == null) - { - connectAttemptCount = 0; - return true; - } - - setStateIfNotClosed(state, State.NOT_READY); - if (cause instanceof IOException) - { - logger.trace("unable to connect on attempt {} to {}", connectAttemptCount, connectionId, cause); - connectAttemptCount++; - connectionRetryFuture = scheduledExecutor.schedule(this::connect, OPEN_RETRY_DELAY_MS * connectAttemptCount, TimeUnit.MILLISECONDS); - } - else - { - JVMStabilityInspector.inspectThrowable(cause); - logger.error("non-IO error attempting to connect to {}", connectionId, cause); - } - return false; - } - - /** - * A callback for handling timeouts when creating a connection/negotiating the handshake. - *

- * Note: this method is *not* invoked from the netty event loop, - * so there's an inherent race with {@link #finishHandshake(HandshakeResult)}, - * as well as any possible connect() reattempts (a seemingly remote race condition, however). - * Therefore, this function tries to lose any races, as much as possible. - * - * @return true if there was a timeout on the connect/handshake; else false. - */ - boolean connectionTimeout(ChannelFuture channelFuture) - { - if (connectionRetryFuture != null) - { - connectionRetryFuture.cancel(false); - connectionRetryFuture = null; - } - connectAttemptCount = 0; - State initialState = state.get(); - if (initialState == State.CLOSED) - return true; - - if (initialState != State.READY) - { - logger.debug("timed out while trying to connect to {}", connectionId); - - channelFuture.channel().close(); - // a last-ditch attempt to let finishHandshake() win the race - if (state.compareAndSet(initialState, State.NOT_READY)) - { - backlog.clear(); - return true; - } - } - return false; - } - - /** - * Process the results of the handshake negotiation. - *

- * Note: this method will be invoked from the netty event loop, - * so there's an inherent race with {@link #connectionTimeout(ChannelFuture)}. - */ - void finishHandshake(HandshakeResult result) - { - // clean up the connector instances before changing the state - if (connectionTimeoutFuture != null) - { - connectionTimeoutFuture.cancel(false); - connectionTimeoutFuture = null; - } - if (connectionRetryFuture != null) - { - connectionRetryFuture.cancel(false); - connectionRetryFuture = null; - } - connectAttemptCount = 0; - - if (result.negotiatedMessagingVersion != HandshakeResult.UNKNOWN_PROTOCOL_VERSION) - { - targetVersion = result.negotiatedMessagingVersion; - MessagingService.instance().setVersion(connectionId.remote(), targetVersion); - } - - switch (result.outcome) - { - case SUCCESS: - assert result.channelWriter != null; - logger.debug("successfully connected to {}, compress = {}, coalescing = {}", connectionId, - shouldCompressConnection(connectionId.local(), connectionId.remote()), - coalescingStrategy.isPresent() ? 
coalescingStrategy.get() : CoalescingStrategies.Strategy.DISABLED); - if (state.get() == State.CLOSED) - { - result.channelWriter.close(); - backlog.clear(); - break; - } - channelWriter = result.channelWriter; - // drain the backlog to the channel - channelWriter.writeBacklog(backlog, true); - // change the state so newly incoming messages can be sent to the channel (without adding to the backlog) - setStateIfNotClosed(state, State.READY); - // ship out any stragglers that got added to the backlog - channelWriter.writeBacklog(backlog, true); - break; - case DISCONNECT: - reconnect(); - break; - case NEGOTIATION_FAILURE: - setStateIfNotClosed(state, State.NOT_READY); - backlog.clear(); - break; - default: - throw new IllegalArgumentException("unhandled result type: " + result.outcome); - } - } - - @VisibleForTesting - static boolean setStateIfNotClosed(AtomicReference state, State newState) - { - State s = state.get(); - if (s == State.CLOSED) - return false; - state.set(newState); - return true; - } - - int getTargetVersion() - { - return targetVersion; - } - - /** - * Handles the result of each message sent. - * - * Note: this function is expected to be invoked on the netty event loop. Also, do not retain any state from - * the input {@code messageResult}. - */ - void handleMessageResult(MessageResult messageResult) - { - completedMessageCount.incrementAndGet(); - - // checking the cause() is an optimized way to tell if the operation was successful (as the cause will be null) - // Note that ExpiredException is just a marker for timeout-ed message we're dropping, but as we already - // incremented the dropped message count in MessageOutHandler, we have nothing to do. 
- Throwable cause = messageResult.future.cause(); - if (cause == null) - return; - - if (cause instanceof ExpiredException) - { - droppedMessageCount.incrementAndGet(); - return; - } - - JVMStabilityInspector.inspectThrowable(cause); - - if (cause instanceof IOException || cause.getCause() instanceof IOException) - { - ChannelWriter writer = messageResult.writer; - if (writer.shouldPurgeBacklog()) - purgeBacklog(); - - // This writer needs to be closed and we need to trigger a reconnection. We really only want to do that - // once for this channel however (and again, no race because we're on the netty event loop). - if (!writer.isClosed() && messageResult.allowReconnect) - { - reconnect(); - writer.close(); - } - - QueuedMessage msg = messageResult.msg; - if (msg != null && msg.shouldRetry()) - { - sendMessage(msg.createRetry()); - } - } - else if (messageResult.future.isCancelled()) - { - // Someone cancelled the future, which we assume meant it doesn't want the message to be sent if it hasn't - // yet. Just ignore. - } - else - { - // Non IO exceptions are likely a programming error so let's not silence them - logger.error("Unexpected error writing on " + connectionId, cause); - } - } - - /** - * Change the IP address on which we connect to the peer. We will attempt to connect to the new address if there - * was a previous connection, and new incoming messages as well as existing {@link #backlog} messages will be sent there. - * Any outstanding messages in the existing channel will still be sent to the previous address (we won't/can't move them from - * one channel to another). 
- */ - void reconnectWithNewIp(InetAddressAndPort newAddr) - { - State currentState = state.get(); - - // if we're closed, ignore the request - if (currentState == State.CLOSED) - return; - - // capture a reference to the current channel, in case it gets swapped out before we can call close() on it - ChannelWriter currentChannel = channelWriter; - connectionId = connectionId.withNewConnectionAddress(newAddr); - - if (currentState != State.NOT_READY) - reconnect(); - - // lastly, push through anything remaining in the existing channel. - if (currentChannel != null) - currentChannel.softClose(); - } - - /** - * Sets the state properly so {@link #connect()} can attempt to reconnect. - */ - void reconnect() - { - if (setStateIfNotClosed(state, State.NOT_READY)) - connect(); - } - - void purgeBacklog() - { - backlog.clear(); - } - - public void close(boolean softClose) - { - state.set(State.CLOSED); - - if (connectionTimeoutFuture != null) - { - connectionTimeoutFuture.cancel(false); - connectionTimeoutFuture = null; - } - - // drain the backlog - if (channelWriter != null) - { - if (softClose) - { - channelWriter.writeBacklog(backlog, false); - channelWriter.softClose(); - } - else - { - backlog.clear(); - channelWriter.close(); - } - - channelWriter = null; - } - } - - @Override - public String toString() - { - return connectionId.toString(); - } - - public Integer getPendingMessages() - { - int pending = backlog.size(); - ChannelWriter chan = channelWriter; - if (chan != null) - pending += (int)chan.pendingMessageCount(); - return pending; - } - - public Long getCompletedMessages() - { - return completedMessageCount.get(); - } - - public Long getDroppedMessages() - { - return droppedMessageCount.get(); - } - - /* - methods specific to testing follow - */ - - @VisibleForTesting - int backlogSize() - { - return backlog.size(); - } - - @VisibleForTesting - void addToBacklog(QueuedMessage msg) - { - backlog.add(msg); - } - - @VisibleForTesting - void 
setChannelWriter(ChannelWriter channelWriter) - { - this.channelWriter = channelWriter; - } - - @VisibleForTesting - ChannelWriter getChannelWriter() - { - return channelWriter; - } - - @VisibleForTesting - void setState(State state) - { - this.state.set(state); - } - - @VisibleForTesting - State getState() - { - return state.get(); - } - - @VisibleForTesting - void setTargetVersion(int targetVersion) - { - this.targetVersion = targetVersion; - } - - @VisibleForTesting - OutboundConnectionIdentifier getConnectionId() - { - return connectionId; - } - - @VisibleForTesting - void setConnectionTimeoutFuture(ScheduledFuture connectionTimeoutFuture) - { - this.connectionTimeoutFuture = connectionTimeoutFuture; - } - - @VisibleForTesting - ScheduledFuture getConnectionTimeoutFuture() - { - return connectionTimeoutFuture; - } - - public boolean isConnected() - { - return state.get() == State.READY; - } -} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/net/async/OutboundMessagingPool.java b/src/java/org/apache/cassandra/net/async/OutboundMessagingPool.java deleted file mode 100644 index 14650a74d3ff..000000000000 --- a/src/java/org/apache/cassandra/net/async/OutboundMessagingPool.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.util.Optional; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.concurrent.Stage; -import org.apache.cassandra.config.Config; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.metrics.ConnectionMetrics; -import org.apache.cassandra.net.BackPressureState; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; -import org.apache.cassandra.utils.CoalescingStrategies; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; - -/** - * Groups a set of outbound connections to a given peer, and routes outgoing messages to the appropriate connection - * (based upon message's type or size). Contains a {@link OutboundMessagingConnection} for each of the - * {@link ConnectionType} type. - */ -public class OutboundMessagingPool -{ - @VisibleForTesting - static final long LARGE_MESSAGE_THRESHOLD = Long.getLong(Config.PROPERTY_PREFIX + "otcp_large_message_threshold", 1024 * 64); - - private final ConnectionMetrics metrics; - private final BackPressureState backPressureState; - - public OutboundMessagingConnection gossipChannel; - public OutboundMessagingConnection largeMessageChannel; - public OutboundMessagingConnection smallMessageChannel; - - /** - * An override address on which to communicate with the peer. Typically used for something like EC2 public IP addresses - * which need to be used for communication between EC2 regions. 
- */ - private InetAddressAndPort preferredRemoteAddr; - - public OutboundMessagingPool(InetAddressAndPort remoteAddr, InetAddressAndPort localAddr, ServerEncryptionOptions encryptionOptions, - BackPressureState backPressureState, IInternodeAuthenticator authenticator) - { - preferredRemoteAddr = remoteAddr; - this.backPressureState = backPressureState; - metrics = new ConnectionMetrics(localAddr, this); - - - smallMessageChannel = new OutboundMessagingConnection(OutboundConnectionIdentifier.small(localAddr, preferredRemoteAddr), - encryptionOptions, coalescingStrategy(remoteAddr), authenticator); - largeMessageChannel = new OutboundMessagingConnection(OutboundConnectionIdentifier.large(localAddr, preferredRemoteAddr), - encryptionOptions, coalescingStrategy(remoteAddr), authenticator); - - // don't attempt coalesce the gossip messages, just ship them out asap (let's not anger the FD on any peer node by any artificial delays) - gossipChannel = new OutboundMessagingConnection(OutboundConnectionIdentifier.gossip(localAddr, preferredRemoteAddr), - encryptionOptions, Optional.empty(), authenticator); - } - - private static Optional coalescingStrategy(InetAddressAndPort remoteAddr) - { - String strategyName = DatabaseDescriptor.getOtcCoalescingStrategy(); - String displayName = remoteAddr.toString(); - return CoalescingStrategies.newCoalescingStrategy(strategyName, - DatabaseDescriptor.getOtcCoalescingWindow(), - OutboundMessagingConnection.logger, - displayName); - - } - - public BackPressureState getBackPressureState() - { - return backPressureState; - } - - public void sendMessage(MessageOut msg, int id) - { - getConnection(msg).sendMessage(msg, id); - } - - @VisibleForTesting - public OutboundMessagingConnection getConnection(MessageOut msg) - { - if (msg.connectionType == null) - { - // optimize for the common path (the small message channel) - if (Stage.GOSSIP != msg.getStage()) - { - return msg.serializedSize(smallMessageChannel.getTargetVersion()) < 
LARGE_MESSAGE_THRESHOLD - ? smallMessageChannel - : largeMessageChannel; - } - return gossipChannel; - } - else - { - return getConnection(msg.connectionType); - } - } - - /** - * Reconnect to the peer using the given {@code addr}. Outstanding messages in each channel will be sent on the - * current channel. Typically this function is used for something like EC2 public IP addresses which need to be used - * for communication between EC2 regions. - * - * @param addr IP Address to use (and prefer) going forward for connecting to the peer - */ - public void reconnectWithNewIp(InetAddressAndPort addr) - { - preferredRemoteAddr = addr; - gossipChannel.reconnectWithNewIp(addr); - largeMessageChannel.reconnectWithNewIp(addr); - smallMessageChannel.reconnectWithNewIp(addr); - } - - /** - * Close each netty channel and it's socket. - * - * @param softClose {@code true} if existing messages in the queue should be sent before closing. - */ - public void close(boolean softClose) - { - gossipChannel.close(softClose); - largeMessageChannel.close(softClose); - smallMessageChannel.close(softClose); - } - - @VisibleForTesting - final OutboundMessagingConnection getConnection(ConnectionType connectionType) - { - switch (connectionType) - { - case SMALL_MESSAGE: - return smallMessageChannel; - case LARGE_MESSAGE: - return largeMessageChannel; - case GOSSIP: - return gossipChannel; - default: - throw new IllegalArgumentException("unsupported connection type: " + connectionType); - } - } - - public void incrementTimeout() - { - metrics.timeouts.mark(); - } - - public long getTimeouts() - { - return metrics.timeouts.getCount(); - } - - public InetAddressAndPort getPreferredRemoteAddr() - { - return preferredRemoteAddr; - } -} diff --git a/src/java/org/apache/cassandra/net/async/QueuedMessage.java b/src/java/org/apache/cassandra/net/async/QueuedMessage.java deleted file mode 100644 index 28e4ba47f5aa..000000000000 --- a/src/java/org/apache/cassandra/net/async/QueuedMessage.java +++ 
/dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.util.concurrent.TimeUnit; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.CoalescingStrategies; - -/** - * A wrapper for outbound messages. All messages will be retried once. 
- */ -public class QueuedMessage implements CoalescingStrategies.Coalescable -{ - public final MessageOut message; - public final int id; - public final long timestampNanos; - public final boolean droppable; - private final boolean retryable; - - public QueuedMessage(MessageOut message, int id) - { - this(message, id, System.nanoTime(), MessagingService.DROPPABLE_VERBS.contains(message.verb), true); - } - - @VisibleForTesting - public QueuedMessage(MessageOut message, int id, long timestampNanos, boolean droppable, boolean retryable) - { - this.message = message; - this.id = id; - this.timestampNanos = timestampNanos; - this.droppable = droppable; - this.retryable = retryable; - } - - /** don't drop a non-droppable message just because it's timestamp is expired */ - public boolean isTimedOut() - { - return droppable && timestampNanos < System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(message.getTimeout()); - } - - public boolean shouldRetry() - { - return retryable; - } - - public QueuedMessage createRetry() - { - return new QueuedMessage(message, id, System.nanoTime(), droppable, false); - } - - public long timestampNanos() - { - return timestampNanos; - } -} diff --git a/src/java/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlus.java b/src/java/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlus.java deleted file mode 100644 index 4e667da83788..000000000000 --- a/src/java/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlus.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.EOFException; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ReadableByteChannel; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicInteger; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelConfig; -import io.netty.util.ReferenceCountUtil; -import org.apache.cassandra.io.util.BufferedDataOutputStreamPlus; -import org.apache.cassandra.io.util.RebufferingInputStream; - -public class RebufferingByteBufDataInputPlus extends RebufferingInputStream implements ReadableByteChannel -{ - /** - * The parent, or owning, buffer of the current buffer being read from ({@link super#buffer}). - */ - private ByteBuf currentBuf; - - private final BlockingQueue queue; - - /** - * The count of live bytes in all {@link ByteBuf}s held by this instance. 
- */ - private final AtomicInteger queuedByteCount; - - private final int lowWaterMark; - private final int highWaterMark; - private final ChannelConfig channelConfig; - - private volatile boolean closed; - - public RebufferingByteBufDataInputPlus(int lowWaterMark, int highWaterMark, ChannelConfig channelConfig) - { - super(Unpooled.EMPTY_BUFFER.nioBuffer()); - - if (lowWaterMark > highWaterMark) - throw new IllegalArgumentException(String.format("low water mark is greater than high water mark: %d vs %d", lowWaterMark, highWaterMark)); - - currentBuf = Unpooled.EMPTY_BUFFER; - this.lowWaterMark = lowWaterMark; - this.highWaterMark = highWaterMark; - this.channelConfig = channelConfig; - queue = new LinkedBlockingQueue<>(); - queuedByteCount = new AtomicInteger(); - } - - /** - * Append a {@link ByteBuf} to the end of the einternal queue. - * - * Note: it's expected this method is invoked on the netty event loop. - */ - public void append(ByteBuf buf) throws IllegalStateException - { - assert buf != null : "buffer cannot be null"; - - if (closed) - { - ReferenceCountUtil.release(buf); - throw new IllegalStateException("stream is already closed, so cannot add another buffer"); - } - - // this slightly undercounts the live count as it doesn't include the currentBuf's size. - // that's ok as the worst we'll do is allow another buffer in and add it to the queue, - // and that point we'll disable auto-read. this is a tradeoff versus making some other member field - // atomic or volatile. - int queuedCount = queuedByteCount.addAndGet(buf.readableBytes()); - if (channelConfig.isAutoRead() && queuedCount > highWaterMark) - channelConfig.setAutoRead(false); - - queue.add(buf); - } - - /** - * {@inheritDoc} - * - * Release open buffers and poll the {@link #queue} for more data. - *

- * This is best, and more or less expected, to be invoked on a consuming thread (not the event loop) - * becasue if we block on the queue we can't fill it on the event loop (as that's where the buffers are coming from). - */ - @Override - protected void reBuffer() throws IOException - { - currentBuf.release(); - buffer = null; - currentBuf = null; - - // possibly re-enable auto-read, *before* blocking on the queue, because if we block on the queue - // without enabling auto-read we'll block forever :( - if (!channelConfig.isAutoRead() && queuedByteCount.get() < lowWaterMark) - channelConfig.setAutoRead(true); - - try - { - currentBuf = queue.take(); - int bytes; - // if we get an explicitly empty buffer, we treat that as an indicator that the input is closed - if (currentBuf == null || (bytes = currentBuf.readableBytes()) == 0) - { - releaseResources(); - throw new EOFException(); - } - - buffer = currentBuf.nioBuffer(currentBuf.readerIndex(), bytes); - assert buffer.remaining() == bytes; - queuedByteCount.addAndGet(-bytes); - return; - } - catch (InterruptedException ie) - { - // nop - ignore - } - } - - @Override - public int read(ByteBuffer dst) throws IOException - { - int readLength = dst.remaining(); - int remaining = readLength; - - while (remaining > 0) - { - if (closed) - throw new EOFException(); - - if (!buffer.hasRemaining()) - reBuffer(); - int copyLength = Math.min(remaining, buffer.remaining()); - - int originalLimit = buffer.limit(); - buffer.limit(buffer.position() + copyLength); - dst.put(buffer); - buffer.limit(originalLimit); - remaining -= copyLength; - } - - return readLength; - } - - /** - * {@inheritDoc} - * - * As long as this method is invoked on the consuming thread the returned value will be accurate. - * - * @throws EOFException thrown when no bytes are buffered and {@link #closed} is true. - */ - @Override - public int available() throws EOFException - { - final int availableBytes = queuedByteCount.get() + (buffer != null ? 
buffer.remaining() : 0); - - if (availableBytes == 0 && closed) - throw new EOFException(); - - if (!channelConfig.isAutoRead() && availableBytes < lowWaterMark) - channelConfig.setAutoRead(true); - - return availableBytes; - } - - @Override - public boolean isOpen() - { - return !closed; - } - - /** - * {@inheritDoc} - * - * Note: This should invoked on the consuming thread. - */ - @Override - public void close() - { - closed = true; - releaseResources(); - } - - private void releaseResources() - { - if (currentBuf != null) - { - if (currentBuf.refCnt() > 0) - currentBuf.release(currentBuf.refCnt()); - currentBuf = null; - buffer = null; - } - - ByteBuf buf; - while ((buf = queue.poll()) != null && buf.refCnt() > 0) - buf.release(buf.refCnt()); - } - - /** - * Mark this stream as closed, but do not release any of the resources. - * - * Note: this is best to be called from the producer thread. - */ - public void markClose() - { - if (!closed) - { - closed = true; - queue.add(Unpooled.EMPTY_BUFFER); - } - } - - /** - * {@inheritDoc} - * - * Note: this is best to be called from the consumer thread. 
- */ - @Override - public String toString() - { - return new StringBuilder(128).append("RebufferingByteBufDataInputPlus: currentBuf = ").append(currentBuf) - .append(" (super.buffer = ").append(buffer).append(')') - .append(", queuedByteCount = ").append(queuedByteCount) - .append(", queue buffers = ").append(queue) - .append(", closed = ").append(closed) - .toString(); - } - - public ByteBufAllocator getAllocator() - { - return channelConfig.getAllocator(); - } - - /** - * Consumes bytes in the stream until the given length - * - * @param writer - * @param len - * @return - * @throws IOException - */ - public long consumeUntil(BufferedDataOutputStreamPlus writer, long len) throws IOException - { - long copied = 0; // number of bytes copied - while (copied < len) - { - if (buffer.remaining() == 0) - { - try - { - reBuffer(); - } - catch (EOFException e) - { - throw new EOFException("EOF after " + copied + " bytes out of " + len); - } - if (buffer.remaining() == 0 && copied < len) - throw new AssertionError("reBuffer() failed to return data"); - } - - int originalLimit = buffer.limit(); - int toCopy = (int) Math.min(len - copied, buffer.remaining()); - buffer.limit(buffer.position() + toCopy); - int written = writer.applyToChannel(c -> c.write(buffer)); - buffer.limit(originalLimit); - copied += written; - } - - return copied; - } -} diff --git a/src/java/org/apache/cassandra/repair/AsymmetricRemoteSyncTask.java b/src/java/org/apache/cassandra/repair/AsymmetricRemoteSyncTask.java index d2a6aebd5777..cf6d84b5af20 100644 --- a/src/java/org/apache/cassandra/repair/AsymmetricRemoteSyncTask.java +++ b/src/java/org/apache/cassandra/repair/AsymmetricRemoteSyncTask.java @@ -24,14 +24,15 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.exceptions.RepairException; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import 
org.apache.cassandra.repair.messages.AsymmetricSyncRequest; -import org.apache.cassandra.repair.messages.SyncRequest; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.streaming.SessionSummary; import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.MerkleTrees; + +import static org.apache.cassandra.net.Verb.ASYMMETRIC_SYNC_REQ; /** * AsymmetricRemoteSyncTask sends {@link AsymmetricSyncRequest} to target node to repair(stream) @@ -52,7 +53,7 @@ public void startSync() AsymmetricSyncRequest request = new AsymmetricSyncRequest(desc, local, nodePair.coordinator, nodePair.peer, rangesToSync, previewKind); String message = String.format("Forwarding streaming repair of %d ranges to %s (to be streamed with %s)", request.ranges.size(), request.fetchingNode, request.fetchFrom); Tracing.traceRepair(message); - MessagingService.instance().sendOneWay(request.createMessage(), request.fetchingNode); + MessagingService.instance().send(Message.out(ASYMMETRIC_SYNC_REQ, request), request.fetchingNode); } public void syncComplete(boolean success, List summaries) diff --git a/src/java/org/apache/cassandra/repair/RepairJob.java b/src/java/org/apache/cassandra/repair/RepairJob.java index a67aac0ec708..f682bfb164d5 100644 --- a/src/java/org/apache/cassandra/repair/RepairJob.java +++ b/src/java/org/apache/cassandra/repair/RepairJob.java @@ -236,7 +236,10 @@ else if (isTransient.test(r1.endpoint) || isTransient.test(r2.endpoint)) } syncTasks.add(task); } + trees.get(i).trees.release(); } + trees.get(trees.size() - 1).trees.release(); + return syncTasks; } diff --git a/src/java/org/apache/cassandra/repair/RepairJobDesc.java b/src/java/org/apache/cassandra/repair/RepairJobDesc.java index 7e7de0738fbe..4aaf655b8258 100644 --- a/src/java/org/apache/cassandra/repair/RepairJobDesc.java +++ b/src/java/org/apache/cassandra/repair/RepairJobDesc.java @@ -26,6 +26,7 @@ import 
org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.io.IVersionedSerializer; @@ -106,7 +107,7 @@ public void serialize(RepairJobDesc desc, DataOutputPlus out, int version) throw UUIDSerializer.serializer.serialize(desc.sessionId, out, version); out.writeUTF(desc.keyspace); out.writeUTF(desc.columnFamily); - MessagingService.validatePartitioner(desc.ranges); + IPartitioner.validate(desc.ranges); out.writeInt(desc.ranges.size()); for (Range rt : desc.ranges) AbstractBounds.tokenSerializer.serialize(rt, out, version); @@ -128,7 +129,7 @@ public RepairJobDesc deserialize(DataInputPlus in, int version) throws IOExcepti for (int i = 0; i < nRanges; i++) { range = (Range) AbstractBounds.tokenSerializer.deserialize(in, - MessagingService.globalPartitioner(), version); + IPartitioner.global(), version); ranges.add(range); } diff --git a/src/java/org/apache/cassandra/repair/RepairMessageVerbHandler.java b/src/java/org/apache/cassandra/repair/RepairMessageVerbHandler.java index 1e92a81165ab..27ffd05926e7 100644 --- a/src/java/org/apache/cassandra/repair/RepairMessageVerbHandler.java +++ b/src/java/org/apache/cassandra/repair/RepairMessageVerbHandler.java @@ -23,17 +23,17 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.db.ColumnFamilyStore; -import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; import org.apache.cassandra.repair.messages.*; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.service.ActiveRepairService; import 
org.apache.cassandra.streaming.PreviewKind; +import static org.apache.cassandra.net.Verb.VALIDATION_RSP; + /** * Handles all repair related message. * @@ -41,6 +41,8 @@ */ public class RepairMessageVerbHandler implements IVerbHandler { + public static RepairMessageVerbHandler instance = new RepairMessageVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(RepairMessageVerbHandler.class); private boolean isIncremental(UUID sessionID) @@ -54,15 +56,15 @@ private PreviewKind previewKind(UUID sessionID) return prs != null ? prs.previewKind : PreviewKind.NONE; } - public void doVerb(final MessageIn message, final int id) + public void doVerb(final Message message) { // TODO add cancel/interrupt message RepairJobDesc desc = message.payload.desc; try { - switch (message.payload.messageType) + switch (message.verb()) { - case PREPARE_MESSAGE: + case PREPARE_MSG: PrepareMessage prepareMessage = (PrepareMessage) message.payload; logger.debug("Preparing, {}", prepareMessage); List columnFamilyStores = new ArrayList<>(prepareMessage.tableIds.size()); @@ -72,29 +74,29 @@ public void doVerb(final MessageIn message, final int id) if (columnFamilyStore == null) { logErrorAndSendFailureResponse(String.format("Table with id %s was dropped during prepare phase of repair", - tableId), message.from, id); + tableId), message); return; } columnFamilyStores.add(columnFamilyStore); } ActiveRepairService.instance.registerParentRepairSession(prepareMessage.parentRepairSession, - message.from, + message.from(), columnFamilyStores, prepareMessage.ranges, prepareMessage.isIncremental, prepareMessage.timestamp, prepareMessage.isGlobal, prepareMessage.previewKind); - MessagingService.instance().sendReply(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), id, message.from); + MessagingService.instance().send(message.emptyResponse(), message.from()); break; - case SNAPSHOT: + case SNAPSHOT_MSG: logger.debug("Snapshotting {}", desc); final ColumnFamilyStore cfs = 
ColumnFamilyStore.getIfExists(desc.keyspace, desc.columnFamily); if (cfs == null) { logErrorAndSendFailureResponse(String.format("Table %s.%s was dropped during snapshot phase of repair", - desc.keyspace, desc.columnFamily), message.from, id); + desc.keyspace, desc.columnFamily), message); return; } @@ -108,11 +110,11 @@ public void doVerb(final MessageIn message, final int id) { repairManager.snapshot(desc.parentSessionId.toString(), desc.ranges, true); } - logger.debug("Enqueuing response to snapshot request {} to {}", desc.sessionId, message.from); - MessagingService.instance().sendReply(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), id, message.from); + logger.debug("Enqueuing response to snapshot request {} to {}", desc.sessionId, message.from()); + MessagingService.instance().send(message.emptyResponse(), message.from()); break; - case VALIDATION_REQUEST: + case VALIDATION_REQ: ValidationRequest validationRequest = (ValidationRequest) message.payload; logger.debug("Validating {}", validationRequest); // trigger read-only compaction @@ -120,17 +122,17 @@ public void doVerb(final MessageIn message, final int id) if (store == null) { logger.error("Table {}.{} was dropped during snapshot phase of repair", desc.keyspace, desc.columnFamily); - MessagingService.instance().sendOneWay(new ValidationComplete(desc).createMessage(), message.from); + MessagingService.instance().send(Message.out(VALIDATION_RSP, new ValidationResponse(desc)), message.from()); return; } ActiveRepairService.instance.consistent.local.maybeSetRepairing(desc.parentSessionId); - Validator validator = new Validator(desc, message.from, validationRequest.nowInSec, + Validator validator = new Validator(desc, message.from(), validationRequest.nowInSec, isIncremental(desc.parentSessionId), previewKind(desc.parentSessionId)); ValidationManager.instance.submitValidation(store, validator); break; - case SYNC_REQUEST: + case SYNC_REQ: // forwarded sync request SyncRequest request = (SyncRequest) 
message.payload; logger.debug("Syncing {}", request); @@ -145,7 +147,7 @@ public void doVerb(final MessageIn message, final int id) task.run(); break; - case ASYMMETRIC_SYNC_REQUEST: + case ASYMMETRIC_SYNC_REQ: // forwarded sync request AsymmetricSyncRequest asymmetricSyncRequest = (AsymmetricSyncRequest) message.payload; logger.debug("Syncing {}", asymmetricSyncRequest); @@ -160,49 +162,49 @@ public void doVerb(final MessageIn message, final int id) asymmetricTask.run(); break; - case CLEANUP: + case CLEANUP_MSG: logger.debug("cleaning up repair"); CleanupMessage cleanup = (CleanupMessage) message.payload; ActiveRepairService.instance.removeParentRepairSession(cleanup.parentRepairSession); - MessagingService.instance().sendReply(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), id, message.from); + MessagingService.instance().send(message.emptyResponse(), message.from()); break; - case CONSISTENT_REQUEST: - ActiveRepairService.instance.consistent.local.handlePrepareMessage(message.from, (PrepareConsistentRequest) message.payload); + case PREPARE_CONSISTENT_REQ: + ActiveRepairService.instance.consistent.local.handlePrepareMessage(message.from(), (PrepareConsistentRequest) message.payload); break; - case CONSISTENT_RESPONSE: + case PREPARE_CONSISTENT_RSP: ActiveRepairService.instance.consistent.coordinated.handlePrepareResponse((PrepareConsistentResponse) message.payload); break; - case FINALIZE_PROPOSE: - ActiveRepairService.instance.consistent.local.handleFinalizeProposeMessage(message.from, (FinalizePropose) message.payload); + case FINALIZE_PROPOSE_MSG: + ActiveRepairService.instance.consistent.local.handleFinalizeProposeMessage(message.from(), (FinalizePropose) message.payload); break; - case FINALIZE_PROMISE: + case FINALIZE_PROMISE_MSG: ActiveRepairService.instance.consistent.coordinated.handleFinalizePromiseMessage((FinalizePromise) message.payload); break; - case FINALIZE_COMMIT: - 
ActiveRepairService.instance.consistent.local.handleFinalizeCommitMessage(message.from, (FinalizeCommit) message.payload); + case FINALIZE_COMMIT_MSG: + ActiveRepairService.instance.consistent.local.handleFinalizeCommitMessage(message.from(), (FinalizeCommit) message.payload); break; - case FAILED_SESSION: + case FAILED_SESSION_MSG: FailSession failure = (FailSession) message.payload; ActiveRepairService.instance.consistent.coordinated.handleFailSessionMessage(failure); - ActiveRepairService.instance.consistent.local.handleFailSessionMessage(message.from, failure); + ActiveRepairService.instance.consistent.local.handleFailSessionMessage(message.from(), failure); break; - case STATUS_REQUEST: - ActiveRepairService.instance.consistent.local.handleStatusRequest(message.from, (StatusRequest) message.payload); + case STATUS_REQ: + ActiveRepairService.instance.consistent.local.handleStatusRequest(message.from(), (StatusRequest) message.payload); break; - case STATUS_RESPONSE: - ActiveRepairService.instance.consistent.local.handleStatusResponse(message.from, (StatusResponse) message.payload); + case STATUS_RSP: + ActiveRepairService.instance.consistent.local.handleStatusResponse(message.from(), (StatusResponse) message.payload); break; default: - ActiveRepairService.instance.handleMessage(message.from, message.payload); + ActiveRepairService.instance.handleMessage(message); break; } } @@ -215,11 +217,10 @@ public void doVerb(final MessageIn message, final int id) } } - private void logErrorAndSendFailureResponse(String errorMessage, InetAddressAndPort to, int id) + private void logErrorAndSendFailureResponse(String errorMessage, Message respondTo) { logger.error(errorMessage); - MessageOut reply = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) - .withParameter(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); - MessagingService.instance().sendReply(reply, id, to); + Message reply = respondTo.failureResponse(RequestFailureReason.UNKNOWN); + 
MessagingService.instance().send(reply, respondTo.from()); } } diff --git a/src/java/org/apache/cassandra/repair/SnapshotTask.java b/src/java/org/apache/cassandra/repair/SnapshotTask.java index acc5186a5fea..40e4b3d09377 100644 --- a/src/java/org/apache/cassandra/repair/SnapshotTask.java +++ b/src/java/org/apache/cassandra/repair/SnapshotTask.java @@ -18,17 +18,18 @@ package org.apache.cassandra.repair; import java.util.concurrent.RunnableFuture; -import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.AbstractFuture; import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IAsyncCallbackWithFailure; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.messages.SnapshotMessage; +import static org.apache.cassandra.net.Verb.SNAPSHOT_MSG; + /** * SnapshotTask is a task that sends snapshot request. */ @@ -37,7 +38,7 @@ public class SnapshotTask extends AbstractFuture implements private final RepairJobDesc desc; private final InetAddressAndPort endpoint; - public SnapshotTask(RepairJobDesc desc, InetAddressAndPort endpoint) + SnapshotTask(RepairJobDesc desc, InetAddressAndPort endpoint) { this.desc = desc; this.endpoint = endpoint; @@ -45,15 +46,15 @@ public SnapshotTask(RepairJobDesc desc, InetAddressAndPort endpoint) public void run() { - MessagingService.instance().sendRR(new SnapshotMessage(desc).createMessage(), - endpoint, - new SnapshotCallback(this), TimeUnit.HOURS.toMillis(1), true); + MessagingService.instance().sendWithCallback(Message.out(SNAPSHOT_MSG, new SnapshotMessage(desc)), + endpoint, + new SnapshotCallback(this)); } /** * Callback for snapshot request. Run on INTERNAL_RESPONSE stage. 
*/ - static class SnapshotCallback implements IAsyncCallbackWithFailure + static class SnapshotCallback implements RequestCallback { final SnapshotTask task; @@ -67,13 +68,19 @@ static class SnapshotCallback implements IAsyncCallbackWithFailure * * @param msg response received. */ - public void response(MessageIn msg) + @Override + public void onResponse(Message msg) { task.set(task.endpoint); } - public boolean isLatencyForSnitch() { return false; } + @Override + public boolean invokeOnFailure() + { + return true; + } + @Override public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) { //listener.failedSnapshot(); diff --git a/src/java/org/apache/cassandra/repair/StreamingRepairTask.java b/src/java/org/apache/cassandra/repair/StreamingRepairTask.java index e9cba8932327..827dce3256ed 100644 --- a/src/java/org/apache/cassandra/repair/StreamingRepairTask.java +++ b/src/java/org/apache/cassandra/repair/StreamingRepairTask.java @@ -29,8 +29,9 @@ import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.repair.messages.SyncComplete; +import org.apache.cassandra.repair.messages.SyncResponse; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.streaming.StreamEvent; import org.apache.cassandra.streaming.StreamEventHandler; @@ -38,9 +39,11 @@ import org.apache.cassandra.streaming.StreamState; import org.apache.cassandra.streaming.StreamOperation; +import static org.apache.cassandra.net.Verb.SYNC_RSP; + /** * StreamingRepairTask performs data streaming between two remote replicas, neither of which is repair coordinator. - * Task will send {@link SyncComplete} message back to coordinator upon streaming completion. + * Task will send {@link SyncResponse} message back to coordinator upon streaming completion. 
*/ public class StreamingRepairTask implements Runnable, StreamEventHandler { @@ -95,19 +98,19 @@ public void handleStreamEvent(StreamEvent event) } /** - * If we succeeded on both stream in and out, reply back to coordinator + * If we succeeded on both stream in and out, respond back to coordinator */ public void onSuccess(StreamState state) { logger.info("[repair #{}] streaming task succeed, returning response to {}", desc.sessionId, initiator); - MessagingService.instance().sendOneWay(new SyncComplete(desc, src, dst, true, state.createSummaries()).createMessage(), initiator); + MessagingService.instance().send(Message.out(SYNC_RSP, new SyncResponse(desc, src, dst, true, state.createSummaries())), initiator); } /** - * If we failed on either stream in or out, reply fail to coordinator + * If we failed on either stream in or out, respond fail to coordinator */ public void onFailure(Throwable t) { - MessagingService.instance().sendOneWay(new SyncComplete(desc, src, dst, false, Collections.emptyList()).createMessage(), initiator); + MessagingService.instance().send(Message.out(SYNC_RSP, new SyncResponse(desc, src, dst, false, Collections.emptyList())), initiator); } } diff --git a/src/java/org/apache/cassandra/repair/SymmetricRemoteSyncTask.java b/src/java/org/apache/cassandra/repair/SymmetricRemoteSyncTask.java index c731bc185ec5..181554a0b5e7 100644 --- a/src/java/org/apache/cassandra/repair/SymmetricRemoteSyncTask.java +++ b/src/java/org/apache/cassandra/repair/SymmetricRemoteSyncTask.java @@ -19,7 +19,6 @@ import java.util.List; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +27,7 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.exceptions.RepairException; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import 
org.apache.cassandra.repair.messages.RepairMessage; import org.apache.cassandra.repair.messages.SyncRequest; @@ -35,7 +35,8 @@ import org.apache.cassandra.streaming.SessionSummary; import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.MerkleTrees; + +import static org.apache.cassandra.net.Verb.SYNC_REQ; /** * SymmetricRemoteSyncTask sends {@link SyncRequest} to remote(non-coordinator) node @@ -52,9 +53,9 @@ public SymmetricRemoteSyncTask(RepairJobDesc desc, InetAddressAndPort r1, InetAd super(desc, r1, r2, differences, previewKind); } - void sendRequest(RepairMessage request, InetAddressAndPort to) + void sendRequest(SyncRequest request, InetAddressAndPort to) { - MessagingService.instance().sendOneWay(request.createMessage(), to); + MessagingService.instance().send(Message.out(SYNC_REQ, request), to); } @Override diff --git a/src/java/org/apache/cassandra/repair/SyncNodePair.java b/src/java/org/apache/cassandra/repair/SyncNodePair.java index b353eb39a729..e10ad5a374f9 100644 --- a/src/java/org/apache/cassandra/repair/SyncNodePair.java +++ b/src/java/org/apache/cassandra/repair/SyncNodePair.java @@ -25,7 +25,8 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; /** * SyncNodePair is used for repair message body to indicate the pair of nodes. 
@@ -71,21 +72,21 @@ public static class NodePairSerializer implements IVersionedSerializer= version 2.0 - MessagingService.instance().sendOneWay(new ValidationComplete(desc).createMessage(), initiator); + respond(new ValidationResponse(desc)); } /** @@ -411,12 +399,51 @@ public void fail() */ public void run() { - // respond to the request that triggered this validation - if (!initiator.equals(FBUtilities.getBroadcastAddressAndPort())) + if (initiatorIsRemote()) { logger.info("{} Sending completed merkle tree to {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily); Tracing.traceRepair("Sending completed merkle tree to {} for {}.{}", initiator, desc.keyspace, desc.columnFamily); } - MessagingService.instance().sendOneWay(new ValidationComplete(desc, trees).createMessage(), initiator); + else + { + logger.info("{} Local completed merkle tree for {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily); + Tracing.traceRepair("Local completed merkle tree for {} for {}.{}", initiator, desc.keyspace, desc.columnFamily); + + } + respond(new ValidationResponse(desc, trees)); + } + + private boolean initiatorIsRemote() + { + return !FBUtilities.getBroadcastAddressAndPort().equals(initiator); + } + + private void respond(ValidationResponse response) + { + if (initiatorIsRemote()) + { + MessagingService.instance().send(Message.out(VALIDATION_RSP, response), initiator); + return; + } + + /* + * For local initiators, DO NOT send the message to self over loopback. This is a wasted ser/de loop + * and a ton of garbage. Instead, move the trees off heap and invoke message handler. We could do it + * directly, since this method will only be called from {@code Stage.ENTI_ENTROPY}, but we do instead + * execute a {@code Runnable} on the stage - in case that assumption ever changes by accident. 
+ */ + StageManager.getStage(Stage.ANTI_ENTROPY).execute(() -> + { + ValidationResponse movedResponse = response; + try + { + movedResponse = response.tryMoveOffHeap(); + } + catch (IOException e) + { + logger.error("Failed to move local merkle tree for {} off heap", desc, e); + } + ActiveRepairService.instance.handleMessage(Message.out(VALIDATION_RSP, movedResponse)); + }); } } diff --git a/src/java/org/apache/cassandra/repair/asymmetric/DifferenceHolder.java b/src/java/org/apache/cassandra/repair/asymmetric/DifferenceHolder.java index c9b7ed7ccf7a..f85c2ebb05d9 100644 --- a/src/java/org/apache/cassandra/repair/asymmetric/DifferenceHolder.java +++ b/src/java/org/apache/cassandra/repair/asymmetric/DifferenceHolder.java @@ -51,9 +51,11 @@ public DifferenceHolder(List trees) TreeResponse r2 = trees.get(j); hd.add(r2.endpoint, MerkleTrees.difference(r1.trees, r2.trees)); } + r1.trees.release(); // and add them to the diff map diffBuilder.put(r1.endpoint, hd); } + trees.get(trees.size() - 1).trees.release(); differences = diffBuilder.build(); } diff --git a/src/java/org/apache/cassandra/repair/consistent/CoordinatorSession.java b/src/java/org/apache/cassandra/repair/consistent/CoordinatorSession.java index b921342ec78b..8f1759afd125 100644 --- a/src/java/org/apache/cassandra/repair/consistent/CoordinatorSession.java +++ b/src/java/org/apache/cassandra/repair/consistent/CoordinatorSession.java @@ -40,8 +40,9 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.repair.RepairSessionResult; import org.apache.cassandra.repair.messages.FailSession; import org.apache.cassandra.repair.messages.FinalizeCommit; @@ -138,11 +139,10 @@ synchronized boolean hasFailed() return getState() == State.FAILED || 
Iterables.any(participantStates.values(), v -> v == State.FAILED); } - protected void sendMessage(InetAddressAndPort destination, RepairMessage message) + protected void sendMessage(InetAddressAndPort destination, Message message) { - logger.trace("Sending {} to {}", message, destination); - MessageOut messageOut = new MessageOut(MessagingService.Verb.REPAIR_MESSAGE, message, RepairMessage.serializer); - MessagingService.instance().sendOneWay(messageOut, destination); + logger.trace("Sending {} to {}", message.payload, destination); + MessagingService.instance().send(message, destination); } public ListenableFuture prepare() @@ -150,7 +150,8 @@ public ListenableFuture prepare() Preconditions.checkArgument(allStates(State.PREPARING)); logger.debug("Beginning prepare phase of incremental repair session {}", sessionID); - PrepareConsistentRequest message = new PrepareConsistentRequest(sessionID, coordinator, participants); + Message message = + Message.out(Verb.PREPARE_CONSISTENT_REQ, new PrepareConsistentRequest(sessionID, coordinator, participants)); for (final InetAddressAndPort participant : participants) { sendMessage(participant, message); @@ -197,7 +198,7 @@ public synchronized ListenableFuture finalizePropose() { Preconditions.checkArgument(allStates(State.REPAIRING)); logger.debug("Proposing finalization of repair session {}", sessionID); - FinalizePropose message = new FinalizePropose(sessionID); + Message message = Message.out(Verb.FINALIZE_PROPOSE_MSG, new FinalizePropose(sessionID)); for (final InetAddressAndPort participant : participants) { sendMessage(participant, message); @@ -233,7 +234,7 @@ public synchronized void finalizeCommit() { Preconditions.checkArgument(allStates(State.FINALIZE_PROMISED)); logger.debug("Committing finalization of repair session {}", sessionID); - FinalizeCommit message = new FinalizeCommit(sessionID); + Message message = Message.out(Verb.FINALIZE_COMMIT_MSG, new FinalizeCommit(sessionID)); for (final InetAddressAndPort 
participant : participants) { sendMessage(participant, message); @@ -244,7 +245,7 @@ public synchronized void finalizeCommit() private void sendFailureMessageToParticipants() { - FailSession message = new FailSession(sessionID); + Message message = Message.out(Verb.FAILED_SESSION_MSG, new FailSession(sessionID)); for (final InetAddressAndPort participant : participants) { if (participantStates.get(participant) != State.FAILED) diff --git a/src/java/org/apache/cassandra/repair/consistent/LocalSessions.java b/src/java/org/apache/cassandra/repair/consistent/LocalSessions.java index c39c4e6f7626..b6103c4ea985 100644 --- a/src/java/org/apache/cassandra/repair/consistent/LocalSessions.java +++ b/src/java/org/apache/cassandra/repair/consistent/LocalSessions.java @@ -72,7 +72,7 @@ import org.apache.cassandra.gms.FailureDetector; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.DataOutputBuffer; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.messages.FailSession; import org.apache.cassandra.repair.messages.FinalizeCommit; @@ -88,6 +88,11 @@ import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; +import static org.apache.cassandra.net.Verb.FAILED_SESSION_MSG; +import static org.apache.cassandra.net.Verb.FINALIZE_PROMISE_MSG; +import static org.apache.cassandra.net.Verb.PREPARE_CONSISTENT_RSP; +import static org.apache.cassandra.net.Verb.STATUS_REQ; +import static org.apache.cassandra.net.Verb.STATUS_RSP; import static org.apache.cassandra.repair.consistent.ConsistentSession.State.*; /** @@ -189,10 +194,11 @@ public void cancelSession(UUID sessionID, boolean force) sessionID, session.coordinator); setStateAndSave(session, FAILED); + Message message = Message.out(FAILED_SESSION_MSG, new FailSession(sessionID)); for (InetAddressAndPort participant : session.participants) { 
if (!participant.equals(getBroadcastAddressAndPort())) - sendMessage(participant, new FailSession(sessionID)); + sendMessage(participant, message); } } @@ -410,7 +416,7 @@ private void syncTable() { TableId tid = Schema.instance.getTableMetadata(keyspace, table).id; ColumnFamilyStore cfm = Schema.instance.getColumnFamilyStoreInstance(tid); - cfm.forceBlockingFlush(); + cfm.forceBlockingFlushToSSTable(); } /** @@ -449,7 +455,7 @@ synchronized void putSessionUnsafe(LocalSession session) private synchronized void putSession(LocalSession session) { Preconditions.checkArgument(!sessions.containsKey(session.sessionID), - "LocalSession {} already exists", session.sessionID); + "LocalSession %s already exists", session.sessionID); Preconditions.checkArgument(started, "sessions cannot be added before LocalSessions is started"); sessions = ImmutableMap.builder() .putAll(sessions) @@ -490,11 +496,10 @@ protected ActiveRepairService.ParentRepairSession getParentRepairSession(UUID se return ActiveRepairService.instance.getParentRepairSession(sessionID); } - protected void sendMessage(InetAddressAndPort destination, RepairMessage message) + protected void sendMessage(InetAddressAndPort destination, Message message) { - logger.trace("sending {} to {}", message, destination); - MessageOut messageOut = new MessageOut(MessagingService.Verb.REPAIR_MESSAGE, message, RepairMessage.serializer); - MessagingService.instance().sendOneWay(messageOut, destination); + logger.trace("sending {} to {}", message.payload, destination); + MessagingService.instance().send(message, destination); } private void setStateAndSave(LocalSession session, ConsistentSession.State state) @@ -537,7 +542,7 @@ public void failSession(UUID sessionID, boolean sendMessage) } if (sendMessage) { - sendMessage(session.coordinator, new FailSession(sessionID)); + sendMessage(session.coordinator, Message.out(FAILED_SESSION_MSG, new FailSession(sessionID))); } } } @@ -608,7 +613,7 @@ public void 
handlePrepareMessage(InetAddressAndPort from, PrepareConsistentReque catch (Throwable e) { logger.error("Error retrieving ParentRepairSession for session {}, responding with failure", sessionID); - sendMessage(coordinator, new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), false)); + sendMessage(coordinator, Message.out(PREPARE_CONSISTENT_RSP, new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), false))); return; } @@ -631,15 +636,14 @@ public void onSuccess(@Nullable Object result) { logger.info("Prepare phase for incremental repair session {} completed", sessionID); if (session.getState() != FAILED) - { setStateAndSave(session, PREPARED); - sendMessage(coordinator, new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), true)); - } else - { logger.info("Session {} failed before anticompaction completed", sessionID); - sendMessage(coordinator, new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), false)); - } + + Message message = + Message.out(PREPARE_CONSISTENT_RSP, + new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), session.getState() != FAILED)); + sendMessage(coordinator, message); } finally { @@ -652,7 +656,9 @@ public void onFailure(Throwable t) try { logger.error("Prepare phase for incremental repair session {} failed", sessionID, t); - sendMessage(coordinator, new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), false)); + sendMessage(coordinator, + Message.out(PREPARE_CONSISTENT_RSP, + new PrepareConsistentResponse(sessionID, getBroadcastAddressAndPort(), false))); failSession(sessionID, false); } finally @@ -681,7 +687,7 @@ public void handleFinalizeProposeMessage(InetAddressAndPort from, FinalizePropos if (session == null) { logger.debug("Received FinalizePropose message for unknown repair session {}, responding with failure", sessionID); - sendMessage(from, new FailSession(sessionID)); + sendMessage(from, Message.out(FAILED_SESSION_MSG, new 
FailSession(sessionID))); return; } @@ -698,7 +704,7 @@ public void handleFinalizeProposeMessage(InetAddressAndPort from, FinalizePropos */ syncTable(); - sendMessage(from, new FinalizePromise(sessionID, getBroadcastAddressAndPort(), true)); + sendMessage(from, Message.out(FINALIZE_PROMISE_MSG, new FinalizePromise(sessionID, getBroadcastAddressAndPort(), true))); logger.debug("Received FinalizePropose message for incremental repair session {}, responded with FinalizePromise", sessionID); } catch (IllegalArgumentException e) @@ -752,7 +758,8 @@ public void handleFailSessionMessage(InetAddressAndPort from, FailSession msg) public void sendStatusRequest(LocalSession session) { logger.debug("Attempting to learn the outcome of unfinished local incremental repair session {}", session.sessionID); - StatusRequest request = new StatusRequest(session.sessionID); + Message request = Message.out(STATUS_REQ, new StatusRequest(session.sessionID)); + for (InetAddressAndPort participant : session.participants) { if (!getBroadcastAddressAndPort().equals(participant) && isAlive(participant)) @@ -770,11 +777,11 @@ public void handleStatusRequest(InetAddressAndPort from, StatusRequest request) if (session == null) { logger.warn("Received status response message for unknown session {}", sessionID); - sendMessage(from, new StatusResponse(sessionID, FAILED)); + sendMessage(from, Message.out(STATUS_RSP, new StatusResponse(sessionID, FAILED))); } else { - sendMessage(from, new StatusResponse(sessionID, session.getState())); + sendMessage(from, Message.out(STATUS_RSP, new StatusResponse(sessionID, session.getState()))); logger.debug("Responding to status response message for incremental repair session {} with local state {}", sessionID, session.getState()); } } diff --git a/src/java/org/apache/cassandra/repair/messages/AsymmetricSyncRequest.java b/src/java/org/apache/cassandra/repair/messages/AsymmetricSyncRequest.java index 6d7626972c30..eacc285aa8f5 100644 --- 
a/src/java/org/apache/cassandra/repair/messages/AsymmetricSyncRequest.java +++ b/src/java/org/apache/cassandra/repair/messages/AsymmetricSyncRequest.java @@ -18,7 +18,6 @@ package org.apache.cassandra.repair.messages; import java.io.IOException; -import java.net.InetAddress; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -26,20 +25,20 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.RepairJobDesc; import org.apache.cassandra.streaming.PreviewKind; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + public class AsymmetricSyncRequest extends RepairMessage { - public static MessageSerializer serializer = new SyncRequestSerializer(); - public final InetAddressAndPort initiator; public final InetAddressAndPort fetchingNode; public final InetAddressAndPort fetchFrom; @@ -48,7 +47,7 @@ public class AsymmetricSyncRequest extends RepairMessage public AsymmetricSyncRequest(RepairJobDesc desc, InetAddressAndPort initiator, InetAddressAndPort fetchingNode, InetAddressAndPort fetchFrom, Collection> ranges, PreviewKind previewKind) { - super(Type.ASYMMETRIC_SYNC_REQUEST, desc); + super(desc); this.initiator = initiator; this.fetchingNode = fetchingNode; this.fetchFrom = fetchFrom; @@ -62,8 +61,7 @@ public boolean equals(Object o) if (!(o instanceof AsymmetricSyncRequest)) return false; AsymmetricSyncRequest req = (AsymmetricSyncRequest)o; - return messageType == 
req.messageType && - desc.equals(req.desc) && + return desc.equals(req.desc) && initiator.equals(req.initiator) && fetchingNode.equals(req.fetchingNode) && fetchFrom.equals(req.fetchFrom) && @@ -73,21 +71,21 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(messageType, desc, initiator, fetchingNode, fetchFrom, ranges); + return Objects.hash(desc, initiator, fetchingNode, fetchFrom, ranges); } - public static class SyncRequestSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(AsymmetricSyncRequest message, DataOutputPlus out, int version) throws IOException { RepairJobDesc.serializer.serialize(message.desc, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.initiator, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.fetchingNode, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.fetchFrom, out, version); + inetAddressAndPortSerializer.serialize(message.initiator, out, version); + inetAddressAndPortSerializer.serialize(message.fetchingNode, out, version); + inetAddressAndPortSerializer.serialize(message.fetchFrom, out, version); out.writeInt(message.ranges.size()); for (Range range : message.ranges) { - MessagingService.validatePartitioner(range); + IPartitioner.validate(range); AbstractBounds.tokenSerializer.serialize(range, out, version); } out.writeInt(message.previewKind.getSerializationVal()); @@ -96,13 +94,13 @@ public void serialize(AsymmetricSyncRequest message, DataOutputPlus out, int ver public AsymmetricSyncRequest deserialize(DataInputPlus in, int version) throws IOException { RepairJobDesc desc = RepairJobDesc.serializer.deserialize(in, version); - InetAddressAndPort owner = CompactEndpointSerializationHelper.instance.deserialize(in, version); - InetAddressAndPort src = CompactEndpointSerializationHelper.instance.deserialize(in, 
version); - InetAddressAndPort dst = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort owner = inetAddressAndPortSerializer.deserialize(in, version); + InetAddressAndPort src = inetAddressAndPortSerializer.deserialize(in, version); + InetAddressAndPort dst = inetAddressAndPortSerializer.deserialize(in, version); int rangesCount = in.readInt(); List> ranges = new ArrayList<>(rangesCount); for (int i = 0; i < rangesCount; ++i) - ranges.add((Range) AbstractBounds.tokenSerializer.deserialize(in, MessagingService.globalPartitioner(), version)); + ranges.add((Range) AbstractBounds.tokenSerializer.deserialize(in, IPartitioner.global(), version)); PreviewKind previewKind = PreviewKind.deserialize(in.readInt()); return new AsymmetricSyncRequest(desc, owner, src, dst, ranges, previewKind); } @@ -110,16 +108,16 @@ public AsymmetricSyncRequest deserialize(DataInputPlus in, int version) throws I public long serializedSize(AsymmetricSyncRequest message, int version) { long size = RepairJobDesc.serializer.serializedSize(message.desc, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(message.initiator, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(message.fetchingNode, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(message.fetchFrom, version); + size += inetAddressAndPortSerializer.serializedSize(message.initiator, version); + size += inetAddressAndPortSerializer.serializedSize(message.fetchingNode, version); + size += inetAddressAndPortSerializer.serializedSize(message.fetchFrom, version); size += TypeSizes.sizeof(message.ranges.size()); for (Range range : message.ranges) size += AbstractBounds.tokenSerializer.serializedSize(range, version); size += TypeSizes.sizeof(message.previewKind.getSerializationVal()); return size; } - } + }; public String toString() { diff --git a/src/java/org/apache/cassandra/repair/messages/CleanupMessage.java 
b/src/java/org/apache/cassandra/repair/messages/CleanupMessage.java index 69d147a29db0..5ec7fc65d395 100644 --- a/src/java/org/apache/cassandra/repair/messages/CleanupMessage.java +++ b/src/java/org/apache/cassandra/repair/messages/CleanupMessage.java @@ -21,6 +21,7 @@ import java.util.Objects; import java.util.UUID; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.UUIDSerializer; @@ -32,12 +33,11 @@ */ public class CleanupMessage extends RepairMessage { - public static MessageSerializer serializer = new CleanupMessageSerializer(); public final UUID parentRepairSession; public CleanupMessage(UUID parentRepairSession) { - super(Type.CLEANUP, null); + super(null); this.parentRepairSession = parentRepairSession; } @@ -47,17 +47,16 @@ public boolean equals(Object o) if (!(o instanceof CleanupMessage)) return false; CleanupMessage other = (CleanupMessage) o; - return messageType == other.messageType && - parentRepairSession.equals(other.parentRepairSession); + return parentRepairSession.equals(other.parentRepairSession); } @Override public int hashCode() { - return Objects.hash(messageType, parentRepairSession); + return Objects.hash(parentRepairSession); } - public static class CleanupMessageSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(CleanupMessage message, DataOutputPlus out, int version) throws IOException { @@ -74,5 +73,5 @@ public long serializedSize(CleanupMessage message, int version) { return UUIDSerializer.serializer.serializedSize(message.parentRepairSession, version); } - } + }; } diff --git a/src/java/org/apache/cassandra/repair/messages/FailSession.java b/src/java/org/apache/cassandra/repair/messages/FailSession.java index 1227cc395e5d..b8c7ad34f798 100644 --- 
a/src/java/org/apache/cassandra/repair/messages/FailSession.java +++ b/src/java/org/apache/cassandra/repair/messages/FailSession.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.UUID; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.UUIDSerializer; @@ -31,7 +32,7 @@ public class FailSession extends RepairMessage public FailSession(UUID sessionID) { - super(Type.FAILED_SESSION, null); + super(null); assert sessionID != null; this.sessionID = sessionID; } @@ -51,7 +52,7 @@ public int hashCode() return sessionID.hashCode(); } - public static final MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(FailSession msg, DataOutputPlus out, int version) throws IOException { diff --git a/src/java/org/apache/cassandra/repair/messages/FinalizeCommit.java b/src/java/org/apache/cassandra/repair/messages/FinalizeCommit.java index a4eb111f7304..bb5cca72b012 100644 --- a/src/java/org/apache/cassandra/repair/messages/FinalizeCommit.java +++ b/src/java/org/apache/cassandra/repair/messages/FinalizeCommit.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.UUID; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.UUIDSerializer; @@ -31,7 +32,7 @@ public class FinalizeCommit extends RepairMessage public FinalizeCommit(UUID sessionID) { - super(Type.FINALIZE_COMMIT, null); + super(null); assert sessionID != null; this.sessionID = sessionID; } @@ -58,7 +59,7 @@ public String toString() '}'; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void 
serialize(FinalizeCommit msg, DataOutputPlus out, int version) throws IOException { diff --git a/src/java/org/apache/cassandra/repair/messages/FinalizePromise.java b/src/java/org/apache/cassandra/repair/messages/FinalizePromise.java index 449748ab45a4..cfdc07c3825e 100644 --- a/src/java/org/apache/cassandra/repair/messages/FinalizePromise.java +++ b/src/java/org/apache/cassandra/repair/messages/FinalizePromise.java @@ -22,12 +22,14 @@ import java.util.UUID; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.utils.UUIDSerializer; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + public class FinalizePromise extends RepairMessage { public final UUID sessionID; @@ -36,7 +38,7 @@ public class FinalizePromise extends RepairMessage public FinalizePromise(UUID sessionID, InetAddressAndPort participant, boolean promised) { - super(Type.FINALIZE_PROMISE, null); + super(null); assert sessionID != null; assert participant != null; this.sessionID = sessionID; @@ -64,26 +66,26 @@ public int hashCode() return result; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(FinalizePromise msg, DataOutputPlus out, int version) throws IOException { UUIDSerializer.serializer.serialize(msg.sessionID, out, version); - CompactEndpointSerializationHelper.instance.serialize(msg.participant, out, version); + inetAddressAndPortSerializer.serialize(msg.participant, out, version); out.writeBoolean(msg.promised); } public FinalizePromise deserialize(DataInputPlus in, int version) throws IOException { return new 
FinalizePromise(UUIDSerializer.serializer.deserialize(in, version), - CompactEndpointSerializationHelper.instance.deserialize(in, version), + inetAddressAndPortSerializer.deserialize(in, version), in.readBoolean()); } public long serializedSize(FinalizePromise msg, int version) { long size = UUIDSerializer.serializer.serializedSize(msg.sessionID, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(msg.participant, version); + size += inetAddressAndPortSerializer.serializedSize(msg.participant, version); size += TypeSizes.sizeof(msg.promised); return size; } diff --git a/src/java/org/apache/cassandra/repair/messages/FinalizePropose.java b/src/java/org/apache/cassandra/repair/messages/FinalizePropose.java index c0c49df72177..c21dd78b9fe4 100644 --- a/src/java/org/apache/cassandra/repair/messages/FinalizePropose.java +++ b/src/java/org/apache/cassandra/repair/messages/FinalizePropose.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.UUID; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.UUIDSerializer; @@ -31,7 +32,7 @@ public class FinalizePropose extends RepairMessage public FinalizePropose(UUID sessionID) { - super(Type.FINALIZE_PROPOSE, null); + super(null); assert sessionID != null; this.sessionID = sessionID; } @@ -58,7 +59,7 @@ public String toString() '}'; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(FinalizePropose msg, DataOutputPlus out, int version) throws IOException { diff --git a/src/java/org/apache/cassandra/repair/messages/PrepareConsistentRequest.java b/src/java/org/apache/cassandra/repair/messages/PrepareConsistentRequest.java index 9aae256e5a80..c1be082464c4 100644 --- 
a/src/java/org/apache/cassandra/repair/messages/PrepareConsistentRequest.java +++ b/src/java/org/apache/cassandra/repair/messages/PrepareConsistentRequest.java @@ -26,12 +26,14 @@ import com.google.common.collect.ImmutableSet; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.utils.UUIDSerializer; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + public class PrepareConsistentRequest extends RepairMessage { public final UUID parentSession; @@ -40,7 +42,7 @@ public class PrepareConsistentRequest extends RepairMessage public PrepareConsistentRequest(UUID parentSession, InetAddressAndPort coordinator, Set participants) { - super(Type.CONSISTENT_REQUEST, null); + super(null); assert parentSession != null; assert coordinator != null; assert participants != null && !participants.isEmpty(); @@ -78,29 +80,28 @@ public String toString() '}'; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { - public void serialize(PrepareConsistentRequest request, DataOutputPlus out, int version) throws IOException { UUIDSerializer.serializer.serialize(request.parentSession, out, version); - CompactEndpointSerializationHelper.instance.serialize(request.coordinator, out, version); + inetAddressAndPortSerializer.serialize(request.coordinator, out, version); out.writeInt(request.participants.size()); for (InetAddressAndPort peer : request.participants) { - CompactEndpointSerializationHelper.instance.serialize(peer, out, version); + inetAddressAndPortSerializer.serialize(peer, out, version); } } public PrepareConsistentRequest 
deserialize(DataInputPlus in, int version) throws IOException { UUID sessionId = UUIDSerializer.serializer.deserialize(in, version); - InetAddressAndPort coordinator = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort coordinator = inetAddressAndPortSerializer.deserialize(in, version); int numPeers = in.readInt(); Set peers = new HashSet<>(numPeers); for (int i = 0; i < numPeers; i++) { - InetAddressAndPort peer = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort peer = inetAddressAndPortSerializer.deserialize(in, version); peers.add(peer); } return new PrepareConsistentRequest(sessionId, coordinator, peers); @@ -109,11 +110,11 @@ public PrepareConsistentRequest deserialize(DataInputPlus in, int version) throw public long serializedSize(PrepareConsistentRequest request, int version) { long size = UUIDSerializer.serializer.serializedSize(request.parentSession, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(request.coordinator, version); + size += inetAddressAndPortSerializer.serializedSize(request.coordinator, version); size += TypeSizes.sizeof(request.participants.size()); for (InetAddressAndPort peer : request.participants) { - size += CompactEndpointSerializationHelper.instance.serializedSize(peer, version); + size += inetAddressAndPortSerializer.serializedSize(peer, version); } return size; } diff --git a/src/java/org/apache/cassandra/repair/messages/PrepareConsistentResponse.java b/src/java/org/apache/cassandra/repair/messages/PrepareConsistentResponse.java index 630f18efda42..00de77dab385 100644 --- a/src/java/org/apache/cassandra/repair/messages/PrepareConsistentResponse.java +++ b/src/java/org/apache/cassandra/repair/messages/PrepareConsistentResponse.java @@ -22,12 +22,14 @@ import java.util.UUID; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; 
import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.utils.UUIDSerializer; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + public class PrepareConsistentResponse extends RepairMessage { public final UUID parentSession; @@ -36,7 +38,7 @@ public class PrepareConsistentResponse extends RepairMessage public PrepareConsistentResponse(UUID parentSession, InetAddressAndPort participant, boolean success) { - super(Type.CONSISTENT_RESPONSE, null); + super(null); assert parentSession != null; assert participant != null; this.parentSession = parentSession; @@ -64,26 +66,26 @@ public int hashCode() return result; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(PrepareConsistentResponse response, DataOutputPlus out, int version) throws IOException { UUIDSerializer.serializer.serialize(response.parentSession, out, version); - CompactEndpointSerializationHelper.instance.serialize(response.participant, out, version); + inetAddressAndPortSerializer.serialize(response.participant, out, version); out.writeBoolean(response.success); } public PrepareConsistentResponse deserialize(DataInputPlus in, int version) throws IOException { return new PrepareConsistentResponse(UUIDSerializer.serializer.deserialize(in, version), - CompactEndpointSerializationHelper.instance.deserialize(in, version), + inetAddressAndPortSerializer.deserialize(in, version), in.readBoolean()); } public long serializedSize(PrepareConsistentResponse response, int version) { long size = UUIDSerializer.serializer.serializedSize(response.parentSession, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(response.participant, version); + size += 
inetAddressAndPortSerializer.serializedSize(response.participant, version); size += TypeSizes.sizeof(response.success); return size; } diff --git a/src/java/org/apache/cassandra/repair/messages/PrepareMessage.java b/src/java/org/apache/cassandra/repair/messages/PrepareMessage.java index 4d59942b8d78..9c485bc2f13d 100644 --- a/src/java/org/apache/cassandra/repair/messages/PrepareMessage.java +++ b/src/java/org/apache/cassandra/repair/messages/PrepareMessage.java @@ -24,9 +24,13 @@ import java.util.Objects; import java.util.UUID; +import com.google.common.base.Preconditions; + import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.net.MessagingService; @@ -37,7 +41,6 @@ public class PrepareMessage extends RepairMessage { - public final static MessageSerializer serializer = new PrepareMessageSerializer(); public final List tableIds; public final Collection> ranges; @@ -49,7 +52,7 @@ public class PrepareMessage extends RepairMessage public PrepareMessage(UUID parentRepairSession, List tableIds, Collection> ranges, boolean isIncremental, long timestamp, boolean isGlobal, PreviewKind previewKind) { - super(Type.PREPARE_MESSAGE, null); + super(null); this.parentRepairSession = parentRepairSession; this.tableIds = tableIds; this.ranges = ranges; @@ -65,8 +68,7 @@ public boolean equals(Object o) if (!(o instanceof PrepareMessage)) return false; PrepareMessage other = (PrepareMessage) o; - return messageType == other.messageType && - parentRepairSession.equals(other.parentRepairSession) && + return parentRepairSession.equals(other.parentRepairSession) && isIncremental == other.isIncremental && isGlobal == other.isGlobal && previewKind == other.previewKind && @@ -78,13 +80,18 @@ 
public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(messageType, parentRepairSession, isGlobal, previewKind, isIncremental, timestamp, tableIds, ranges); + return Objects.hash(parentRepairSession, isGlobal, previewKind, isIncremental, timestamp, tableIds, ranges); } - public static class PrepareMessageSerializer implements MessageSerializer + private static final String MIXED_MODE_ERROR = "Some nodes involved in repair are on an incompatible major version. " + + "Repair is not supported in mixed major version clusters."; + + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(PrepareMessage message, DataOutputPlus out, int version) throws IOException { + Preconditions.checkArgument(version == MessagingService.current_version, MIXED_MODE_ERROR); + out.writeInt(message.tableIds.size()); for (TableId tableId : message.tableIds) tableId.serialize(out); @@ -92,7 +99,7 @@ public void serialize(PrepareMessage message, DataOutputPlus out, int version) t out.writeInt(message.ranges.size()); for (Range r : message.ranges) { - MessagingService.validatePartitioner(r); + IPartitioner.validate(r); Range.tokenSerializer.serialize(r, out, version); } out.writeBoolean(message.isIncremental); @@ -103,6 +110,8 @@ public void serialize(PrepareMessage message, DataOutputPlus out, int version) t public PrepareMessage deserialize(DataInputPlus in, int version) throws IOException { + Preconditions.checkArgument(version == MessagingService.current_version, MIXED_MODE_ERROR); + int tableIdCount = in.readInt(); List tableIds = new ArrayList<>(tableIdCount); for (int i = 0; i < tableIdCount; i++) @@ -111,7 +120,7 @@ public PrepareMessage deserialize(DataInputPlus in, int version) throws IOExcept int rangeCount = in.readInt(); List> ranges = new ArrayList<>(rangeCount); for (int i = 0; i < rangeCount; i++) - ranges.add((Range) Range.tokenSerializer.deserialize(in, MessagingService.globalPartitioner(), 
version)); + ranges.add((Range) Range.tokenSerializer.deserialize(in, IPartitioner.global(), version)); boolean isIncremental = in.readBoolean(); long timestamp = in.readLong(); boolean isGlobal = in.readBoolean(); @@ -121,6 +130,8 @@ public PrepareMessage deserialize(DataInputPlus in, int version) throws IOExcept public long serializedSize(PrepareMessage message, int version) { + Preconditions.checkArgument(version == MessagingService.current_version, MIXED_MODE_ERROR); + long size; size = TypeSizes.sizeof(message.tableIds.size()); for (TableId tableId : message.tableIds) @@ -135,7 +146,7 @@ public long serializedSize(PrepareMessage message, int version) size += TypeSizes.sizeof(message.previewKind.getSerializationVal()); return size; } - } + }; @Override public String toString() diff --git a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java index 09c60604552f..3137b4e474ae 100644 --- a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java +++ b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java @@ -17,15 +17,6 @@ */ package org.apache.cassandra.repair.messages; -import java.io.IOException; - -import com.google.common.base.Preconditions; - -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.RepairJobDesc; /** @@ -35,91 +26,10 @@ */ public abstract class RepairMessage { - public static final IVersionedSerializer serializer = new RepairMessageSerializer(); - - public static interface MessageSerializer extends IVersionedSerializer {} - - public static final int MIN_MESSAGING_VERSION = MessagingService.VERSION_40; - private static final String MIXED_MODE_ERROR = "Some nodes involved in repair are on an incompatible major version. 
" + - "Repair is not supported in mixed major version clusters."; - - public enum Type - { - VALIDATION_REQUEST(0, ValidationRequest.serializer), - VALIDATION_COMPLETE(1, ValidationComplete.serializer), - SYNC_REQUEST(2, SyncRequest.serializer), - SYNC_COMPLETE(3, SyncComplete.serializer), - PREPARE_MESSAGE(5, PrepareMessage.serializer), - SNAPSHOT(6, SnapshotMessage.serializer), - CLEANUP(7, CleanupMessage.serializer), - - CONSISTENT_REQUEST(8, PrepareConsistentRequest.serializer), - CONSISTENT_RESPONSE(9, PrepareConsistentResponse.serializer), - FINALIZE_PROPOSE(10, FinalizePropose.serializer), - FINALIZE_PROMISE(11, FinalizePromise.serializer), - FINALIZE_COMMIT(12, FinalizeCommit.serializer), - FAILED_SESSION(13, FailSession.serializer), - STATUS_REQUEST(14, StatusRequest.serializer), - STATUS_RESPONSE(15, StatusResponse.serializer), - ASYMMETRIC_SYNC_REQUEST(16, AsymmetricSyncRequest.serializer); - - private final byte type; - private final MessageSerializer serializer; - - private Type(int type, MessageSerializer serializer) - { - this.type = (byte) type; - this.serializer = serializer; - } - - public static Type fromByte(byte b) - { - for (Type t : values()) - { - if (t.type == b) - return t; - } - throw new IllegalArgumentException("Unknown RepairMessage.Type: " + b); - } - } - - public final Type messageType; public final RepairJobDesc desc; - protected RepairMessage(Type messageType, RepairJobDesc desc) + protected RepairMessage(RepairJobDesc desc) { - this.messageType = messageType; this.desc = desc; } - - public MessageOut createMessage() - { - return new MessageOut<>(MessagingService.Verb.REPAIR_MESSAGE, this, RepairMessage.serializer); - } - - - public static class RepairMessageSerializer implements MessageSerializer - { - public void serialize(RepairMessage message, DataOutputPlus out, int version) throws IOException - { - Preconditions.checkArgument(version >= MIN_MESSAGING_VERSION, MIXED_MODE_ERROR); - out.write(message.messageType.type); - 
message.messageType.serializer.serialize(message, out, version); - } - - public RepairMessage deserialize(DataInputPlus in, int version) throws IOException - { - Preconditions.checkArgument(version >= MIN_MESSAGING_VERSION, MIXED_MODE_ERROR); - RepairMessage.Type messageType = RepairMessage.Type.fromByte(in.readByte()); - return messageType.serializer.deserialize(in, version); - } - - public long serializedSize(RepairMessage message, int version) - { - Preconditions.checkArgument(version >= MIN_MESSAGING_VERSION, MIXED_MODE_ERROR); - long size = 1; // for messageType byte - size += message.messageType.serializer.serializedSize(message, version); - return size; - } - } } diff --git a/src/java/org/apache/cassandra/repair/messages/SnapshotMessage.java b/src/java/org/apache/cassandra/repair/messages/SnapshotMessage.java index d4737d3e96ce..c18950a097b8 100644 --- a/src/java/org/apache/cassandra/repair/messages/SnapshotMessage.java +++ b/src/java/org/apache/cassandra/repair/messages/SnapshotMessage.java @@ -20,17 +20,16 @@ import java.io.IOException; import java.util.Objects; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.repair.RepairJobDesc; public class SnapshotMessage extends RepairMessage { - public final static MessageSerializer serializer = new SnapshotMessageSerializer(); - public SnapshotMessage(RepairJobDesc desc) { - super(Type.SNAPSHOT, desc); + super(desc); } @Override @@ -39,16 +38,16 @@ public boolean equals(Object o) if (!(o instanceof SnapshotMessage)) return false; SnapshotMessage other = (SnapshotMessage) o; - return messageType == other.messageType; + return desc.equals(other.desc); } @Override public int hashCode() { - return Objects.hash(messageType); + return Objects.hash(desc); } - public static class SnapshotMessageSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = 
new IVersionedSerializer() { public void serialize(SnapshotMessage message, DataOutputPlus out, int version) throws IOException { @@ -65,5 +64,5 @@ public long serializedSize(SnapshotMessage message, int version) { return RepairJobDesc.serializer.serializedSize(message.desc, version); } - } + }; } diff --git a/src/java/org/apache/cassandra/repair/messages/StatusRequest.java b/src/java/org/apache/cassandra/repair/messages/StatusRequest.java index f6a2b827fee0..09354e63503b 100644 --- a/src/java/org/apache/cassandra/repair/messages/StatusRequest.java +++ b/src/java/org/apache/cassandra/repair/messages/StatusRequest.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.UUID; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.utils.UUIDSerializer; @@ -31,7 +32,7 @@ public class StatusRequest extends RepairMessage public StatusRequest(UUID sessionID) { - super(Type.STATUS_REQUEST, null); + super(null); this.sessionID = sessionID; } @@ -57,7 +58,7 @@ public String toString() '}'; } - public static MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(StatusRequest msg, DataOutputPlus out, int version) throws IOException { diff --git a/src/java/org/apache/cassandra/repair/messages/StatusResponse.java b/src/java/org/apache/cassandra/repair/messages/StatusResponse.java index 99eb76ba664d..e62d337df052 100644 --- a/src/java/org/apache/cassandra/repair/messages/StatusResponse.java +++ b/src/java/org/apache/cassandra/repair/messages/StatusResponse.java @@ -22,6 +22,7 @@ import java.util.UUID; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import 
org.apache.cassandra.repair.consistent.ConsistentSession; @@ -34,7 +35,7 @@ public class StatusResponse extends RepairMessage public StatusResponse(UUID sessionID, ConsistentSession.State state) { - super(Type.STATUS_RESPONSE, null); + super(null); assert sessionID != null; assert state != null; this.sessionID = sessionID; @@ -67,7 +68,7 @@ public String toString() '}'; } - public static final MessageSerializer serializer = new MessageSerializer() + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(StatusResponse msg, DataOutputPlus out, int version) throws IOException { diff --git a/src/java/org/apache/cassandra/repair/messages/SyncRequest.java b/src/java/org/apache/cassandra/repair/messages/SyncRequest.java index a0bf4e2f90de..341455f7cdf6 100644 --- a/src/java/org/apache/cassandra/repair/messages/SyncRequest.java +++ b/src/java/org/apache/cassandra/repair/messages/SyncRequest.java @@ -25,16 +25,18 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.RepairJobDesc; import org.apache.cassandra.streaming.PreviewKind; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + /** * Body part of SYNC_REQUEST repair message. * Request {@code src} node to sync data with {@code dst} node for range {@code ranges}. 
@@ -43,8 +45,6 @@ */ public class SyncRequest extends RepairMessage { - public static MessageSerializer serializer = new SyncRequestSerializer(); - public final InetAddressAndPort initiator; public final InetAddressAndPort src; public final InetAddressAndPort dst; @@ -53,7 +53,7 @@ public class SyncRequest extends RepairMessage public SyncRequest(RepairJobDesc desc, InetAddressAndPort initiator, InetAddressAndPort src, InetAddressAndPort dst, Collection> ranges, PreviewKind previewKind) { - super(Type.SYNC_REQUEST, desc); + super(desc); this.initiator = initiator; this.src = src; this.dst = dst; @@ -67,8 +67,7 @@ public boolean equals(Object o) if (!(o instanceof SyncRequest)) return false; SyncRequest req = (SyncRequest)o; - return messageType == req.messageType && - desc.equals(req.desc) && + return desc.equals(req.desc) && initiator.equals(req.initiator) && src.equals(req.src) && dst.equals(req.dst) && @@ -79,21 +78,21 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(messageType, desc, initiator, src, dst, ranges, previewKind); + return Objects.hash(desc, initiator, src, dst, ranges, previewKind); } - public static class SyncRequestSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(SyncRequest message, DataOutputPlus out, int version) throws IOException { RepairJobDesc.serializer.serialize(message.desc, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.initiator, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.src, out, version); - CompactEndpointSerializationHelper.instance.serialize(message.dst, out, version); + inetAddressAndPortSerializer.serialize(message.initiator, out, version); + inetAddressAndPortSerializer.serialize(message.src, out, version); + inetAddressAndPortSerializer.serialize(message.dst, out, version); out.writeInt(message.ranges.size()); for (Range 
range : message.ranges) { - MessagingService.validatePartitioner(range); + IPartitioner.validate(range); AbstractBounds.tokenSerializer.serialize(range, out, version); } out.writeInt(message.previewKind.getSerializationVal()); @@ -102,13 +101,13 @@ public void serialize(SyncRequest message, DataOutputPlus out, int version) thro public SyncRequest deserialize(DataInputPlus in, int version) throws IOException { RepairJobDesc desc = RepairJobDesc.serializer.deserialize(in, version); - InetAddressAndPort owner = CompactEndpointSerializationHelper.instance.deserialize(in, version); - InetAddressAndPort src = CompactEndpointSerializationHelper.instance.deserialize(in, version); - InetAddressAndPort dst = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort owner = inetAddressAndPortSerializer.deserialize(in, version); + InetAddressAndPort src = inetAddressAndPortSerializer.deserialize(in, version); + InetAddressAndPort dst = inetAddressAndPortSerializer.deserialize(in, version); int rangesCount = in.readInt(); List> ranges = new ArrayList<>(rangesCount); for (int i = 0; i < rangesCount; ++i) - ranges.add((Range) AbstractBounds.tokenSerializer.deserialize(in, MessagingService.globalPartitioner(), version)); + ranges.add((Range) AbstractBounds.tokenSerializer.deserialize(in, IPartitioner.global(), version)); PreviewKind previewKind = PreviewKind.deserialize(in.readInt()); return new SyncRequest(desc, owner, src, dst, ranges, previewKind); } @@ -116,14 +115,14 @@ public SyncRequest deserialize(DataInputPlus in, int version) throws IOException public long serializedSize(SyncRequest message, int version) { long size = RepairJobDesc.serializer.serializedSize(message.desc, version); - size += 3 * CompactEndpointSerializationHelper.instance.serializedSize(message.initiator, version); + size += 3 * inetAddressAndPortSerializer.serializedSize(message.initiator, version); size += TypeSizes.sizeof(message.ranges.size()); for (Range range : 
message.ranges) size += AbstractBounds.tokenSerializer.serializedSize(range, version); size += TypeSizes.sizeof(message.previewKind.getSerializationVal()); return size; } - } + }; @Override public String toString() diff --git a/src/java/org/apache/cassandra/repair/messages/SyncComplete.java b/src/java/org/apache/cassandra/repair/messages/SyncResponse.java similarity index 79% rename from src/java/org/apache/cassandra/repair/messages/SyncComplete.java rename to src/java/org/apache/cassandra/repair/messages/SyncResponse.java index c51d1fd0a564..e7e7985fff34 100644 --- a/src/java/org/apache/cassandra/repair/messages/SyncComplete.java +++ b/src/java/org/apache/cassandra/repair/messages/SyncResponse.java @@ -23,6 +23,7 @@ import java.util.Objects; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; @@ -34,10 +35,8 @@ * * @since 2.0 */ -public class SyncComplete extends RepairMessage +public class SyncResponse extends RepairMessage { - public static final MessageSerializer serializer = new SyncCompleteSerializer(); - /** nodes that involved in this sync */ public final SyncNodePair nodes; /** true if sync success, false otherwise */ @@ -45,17 +44,17 @@ public class SyncComplete extends RepairMessage public final List summaries; - public SyncComplete(RepairJobDesc desc, SyncNodePair nodes, boolean success, List summaries) + public SyncResponse(RepairJobDesc desc, SyncNodePair nodes, boolean success, List summaries) { - super(Type.SYNC_COMPLETE, desc); + super(desc); this.nodes = nodes; this.success = success; this.summaries = summaries; } - public SyncComplete(RepairJobDesc desc, InetAddressAndPort endpoint1, InetAddressAndPort endpoint2, boolean success, List summaries) + public SyncResponse(RepairJobDesc desc, InetAddressAndPort endpoint1, InetAddressAndPort endpoint2, 
boolean success, List summaries) { - super(Type.SYNC_COMPLETE, desc); + super(desc); this.summaries = summaries; this.nodes = new SyncNodePair(endpoint1, endpoint2); this.success = success; @@ -64,11 +63,10 @@ public SyncComplete(RepairJobDesc desc, InetAddressAndPort endpoint1, InetAddres @Override public boolean equals(Object o) { - if (!(o instanceof SyncComplete)) + if (!(o instanceof SyncResponse)) return false; - SyncComplete other = (SyncComplete)o; - return messageType == other.messageType && - desc.equals(other.desc) && + SyncResponse other = (SyncResponse)o; + return desc.equals(other.desc) && success == other.success && nodes.equals(other.nodes) && summaries.equals(other.summaries); @@ -77,12 +75,12 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(messageType, desc, success, nodes, summaries); + return Objects.hash(desc, success, nodes, summaries); } - private static class SyncCompleteSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { - public void serialize(SyncComplete message, DataOutputPlus out, int version) throws IOException + public void serialize(SyncResponse message, DataOutputPlus out, int version) throws IOException { RepairJobDesc.serializer.serialize(message.desc, out, version); SyncNodePair.serializer.serialize(message.nodes, out, version); @@ -95,7 +93,7 @@ public void serialize(SyncComplete message, DataOutputPlus out, int version) thr } } - public SyncComplete deserialize(DataInputPlus in, int version) throws IOException + public SyncResponse deserialize(DataInputPlus in, int version) throws IOException { RepairJobDesc desc = RepairJobDesc.serializer.deserialize(in, version); SyncNodePair nodes = SyncNodePair.serializer.deserialize(in, version); @@ -108,10 +106,10 @@ public SyncComplete deserialize(DataInputPlus in, int version) throws IOExceptio summaries.add(SessionSummary.serializer.deserialize(in, version)); } - return 
new SyncComplete(desc, nodes, success, summaries); + return new SyncResponse(desc, nodes, success, summaries); } - public long serializedSize(SyncComplete message, int version) + public long serializedSize(SyncResponse message, int version) { long size = RepairJobDesc.serializer.serializedSize(message.desc, version); size += SyncNodePair.serializer.serializedSize(message.nodes, version); @@ -125,5 +123,5 @@ public long serializedSize(SyncComplete message, int version) return size; } - } + }; } diff --git a/src/java/org/apache/cassandra/repair/messages/ValidationRequest.java b/src/java/org/apache/cassandra/repair/messages/ValidationRequest.java index 646624462ad5..f9a1f4e2be76 100644 --- a/src/java/org/apache/cassandra/repair/messages/ValidationRequest.java +++ b/src/java/org/apache/cassandra/repair/messages/ValidationRequest.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.repair.RepairJobDesc; @@ -31,13 +32,11 @@ */ public class ValidationRequest extends RepairMessage { - public static MessageSerializer serializer = new ValidationRequestSerializer(); - public final int nowInSec; public ValidationRequest(RepairJobDesc desc, int nowInSec) { - super(Type.VALIDATION_REQUEST, desc); + super(desc); this.nowInSec = nowInSec; } @@ -65,7 +64,7 @@ public int hashCode() return nowInSec; } - public static class ValidationRequestSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { public void serialize(ValidationRequest message, DataOutputPlus out, int version) throws IOException { @@ -85,5 +84,5 @@ public long serializedSize(ValidationRequest message, int version) size += TypeSizes.sizeof(message.nowInSec); return size; } - } + }; } diff --git 
a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java b/src/java/org/apache/cassandra/repair/messages/ValidationResponse.java similarity index 69% rename from src/java/org/apache/cassandra/repair/messages/ValidationComplete.java rename to src/java/org/apache/cassandra/repair/messages/ValidationResponse.java index 704bffb09353..d9f44677bcdc 100644 --- a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java +++ b/src/java/org/apache/cassandra/repair/messages/ValidationResponse.java @@ -21,6 +21,7 @@ import java.util.Objects; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.repair.RepairJobDesc; @@ -31,22 +32,20 @@ * * @since 2.0 */ -public class ValidationComplete extends RepairMessage +public class ValidationResponse extends RepairMessage { - public static MessageSerializer serializer = new ValidationCompleteSerializer(); - /** Merkle hash tree response. Null if validation failed. */ public final MerkleTrees trees; - public ValidationComplete(RepairJobDesc desc) + public ValidationResponse(RepairJobDesc desc) { - super(Type.VALIDATION_COMPLETE, desc); + super(desc); trees = null; } - public ValidationComplete(RepairJobDesc desc, MerkleTrees trees) + public ValidationResponse(RepairJobDesc desc, MerkleTrees trees) { - super(Type.VALIDATION_COMPLETE, desc); + super(desc); assert trees != null; this.trees = trees; } @@ -56,26 +55,34 @@ public boolean success() return trees != null; } + /** + * @return a new {@link ValidationResponse} instance with all trees moved off heap, or {@code this} + * if it's a failure response. + */ + public ValidationResponse tryMoveOffHeap() throws IOException + { + return trees == null ? 
this : new ValidationResponse(desc, trees.tryMoveOffHeap()); + } + @Override public boolean equals(Object o) { - if (!(o instanceof ValidationComplete)) + if (!(o instanceof ValidationResponse)) return false; - ValidationComplete other = (ValidationComplete)o; - return messageType == other.messageType && - desc.equals(other.desc); + ValidationResponse other = (ValidationResponse)o; + return desc.equals(other.desc); } @Override public int hashCode() { - return Objects.hash(messageType, desc); + return Objects.hash(desc); } - private static class ValidationCompleteSerializer implements MessageSerializer + public static final IVersionedSerializer serializer = new IVersionedSerializer() { - public void serialize(ValidationComplete message, DataOutputPlus out, int version) throws IOException + public void serialize(ValidationResponse message, DataOutputPlus out, int version) throws IOException { RepairJobDesc.serializer.serialize(message.desc, out, version); out.writeBoolean(message.success()); @@ -83,7 +90,7 @@ public void serialize(ValidationComplete message, DataOutputPlus out, int versio MerkleTrees.serializer.serialize(message.trees, out, version); } - public ValidationComplete deserialize(DataInputPlus in, int version) throws IOException + public ValidationResponse deserialize(DataInputPlus in, int version) throws IOException { RepairJobDesc desc = RepairJobDesc.serializer.deserialize(in, version); boolean success = in.readBoolean(); @@ -91,13 +98,13 @@ public ValidationComplete deserialize(DataInputPlus in, int version) throws IOEx if (success) { MerkleTrees trees = MerkleTrees.serializer.deserialize(in, version); - return new ValidationComplete(desc, trees); + return new ValidationResponse(desc, trees); } - return new ValidationComplete(desc); + return new ValidationResponse(desc); } - public long serializedSize(ValidationComplete message, int version) + public long serializedSize(ValidationResponse message, int version) { long size = 
RepairJobDesc.serializer.serializedSize(message.desc, version); size += TypeSizes.sizeof(message.success()); @@ -105,5 +112,5 @@ public long serializedSize(ValidationComplete message, int version) size += MerkleTrees.serializer.serializedSize(message.trees, version); return size; } - } + }; } diff --git a/src/java/org/apache/cassandra/schema/CompressionParams.java b/src/java/org/apache/cassandra/schema/CompressionParams.java index 40a4be351b8b..102edd817e50 100644 --- a/src/java/org/apache/cassandra/schema/CompressionParams.java +++ b/src/java/org/apache/cassandra/schema/CompressionParams.java @@ -41,6 +41,7 @@ import org.apache.cassandra.io.compress.*; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.streaming.messages.StreamMessage; import static java.lang.String.format; @@ -596,7 +597,7 @@ public void serialize(CompressionParams parameters, DataOutputPlus out, int vers out.writeUTF(entry.getValue()); } out.writeInt(parameters.chunkLength()); - if (version >= StreamMessage.VERSION_40) + if (version >= MessagingService.VERSION_40) out.writeInt(parameters.maxCompressedLength); else if (parameters.maxCompressedLength != Integer.MAX_VALUE) @@ -616,7 +617,7 @@ public CompressionParams deserialize(DataInputPlus in, int version) throws IOExc } int chunkLength = in.readInt(); int minCompressRatio = Integer.MAX_VALUE; // Earlier Cassandra cannot use uncompressed chunks. 
- if (version >= StreamMessage.VERSION_40) minCompressRatio = in.readInt(); CompressionParams parameters; @@ -641,7 +642,7 @@ public long serializedSize(CompressionParams parameters, int version) size += TypeSizes.sizeof(entry.getValue()); } size += TypeSizes.sizeof(parameters.chunkLength()); - if (version >= StreamMessage.VERSION_40) + if (version >= MessagingService.VERSION_40) size += TypeSizes.sizeof(parameters.maxCompressedLength()); return size; } diff --git a/src/java/org/apache/cassandra/schema/MigrationManager.java b/src/java/org/apache/cassandra/schema/MigrationManager.java index 32a6cf1f7dc2..69a72bf631af 100644 --- a/src/java/org/apache/cassandra/schema/MigrationManager.java +++ b/src/java/org/apache/cassandra/schema/MigrationManager.java @@ -38,11 +38,13 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.Keyspaces.KeyspacesDiff; import org.apache.cassandra.utils.FBUtilities; +import static org.apache.cassandra.net.Verb.SCHEMA_PUSH_REQ; + public class MigrationManager { private static final Logger logger = LoggerFactory.getLogger(MigrationManager.class); @@ -152,8 +154,8 @@ static boolean shouldPullSchemaFrom(InetAddressAndPort endpoint) * Don't request schema from nodes with a different or unknown major version (may have incompatible schema) * Don't request schema from fat clients */ - return MessagingService.instance().knowsVersion(endpoint) - && MessagingService.instance().getRawVersion(endpoint) == MessagingService.current_version + return MessagingService.instance().versions.knows(endpoint) + && MessagingService.instance().versions.getRaw(endpoint) == MessagingService.current_version && 
!Gossiper.instance.isGossipOnlyMember(endpoint); } @@ -161,8 +163,8 @@ private static boolean shouldPushSchemaTo(InetAddressAndPort endpoint) { // only push schema to nodes with known and equal versions return !endpoint.equals(FBUtilities.getBroadcastAddressAndPort()) - && MessagingService.instance().knowsVersion(endpoint) - && MessagingService.instance().getRawVersion(endpoint) == MessagingService.current_version; + && MessagingService.instance().versions.knows(endpoint) + && MessagingService.instance().versions.getRaw(endpoint) == MessagingService.current_version; } public static boolean isReadyForBootstrap() @@ -315,14 +317,6 @@ private static void announce(Mutation.SimpleBuilder schema, boolean announceLoca announce(mutations); } - private static void pushSchemaMutation(InetAddressAndPort endpoint, Collection schema) - { - MessageOut> msg = new MessageOut<>(MessagingService.Verb.DEFINITIONS_UPDATE, - schema, - MigrationsSerializer.instance); - MessagingService.instance().sendOneWay(msg, endpoint); - } - // Returns a future on the local application of the schema private static void announce(Collection schema) { @@ -330,11 +324,12 @@ private static void announce(Collection schema) Set schemaDestinationEndpoints = new HashSet<>(); Set schemaEndpointsIgnored = new HashSet<>(); + Message> message = Message.out(SCHEMA_PUSH_REQ, schema); for (InetAddressAndPort endpoint : Gossiper.instance.getLiveMembers()) { if (shouldPushSchemaTo(endpoint)) { - pushSchemaMutation(endpoint, schema); + MessagingService.instance().send(message, endpoint); schemaDestinationEndpoints.add(endpoint); } else @@ -363,11 +358,12 @@ public static KeyspacesDiff announce(SchemaTransformation transformation, boolea Set schemaDestinationEndpoints = new HashSet<>(); Set schemaEndpointsIgnored = new HashSet<>(); + Message> message = Message.out(SCHEMA_PUSH_REQ, result.mutations); for (InetAddressAndPort endpoint : Gossiper.instance.getLiveMembers()) { if (shouldPushSchemaTo(endpoint)) { - 
pushSchemaMutation(endpoint, result.mutations); + MessagingService.instance().send(message, endpoint); schemaDestinationEndpoints.add(endpoint); } else diff --git a/src/java/org/apache/cassandra/schema/MigrationTask.java b/src/java/org/apache/cassandra/schema/MigrationTask.java index bf96fb27b959..3308893ffcc6 100644 --- a/src/java/org/apache/cassandra/schema/MigrationTask.java +++ b/src/java/org/apache/cassandra/schema/MigrationTask.java @@ -32,12 +32,14 @@ import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.gms.FailureDetector; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IAsyncCallback; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.utils.WrappedRunnable; +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.Verb.SCHEMA_PULL_REQ; + final class MigrationTask extends WrappedRunnable { private static final Logger logger = LoggerFactory.getLogger(MigrationTask.class); @@ -78,32 +80,23 @@ public void runMayThrow() throws Exception return; } - MessageOut message = new MessageOut<>(MessagingService.Verb.MIGRATION_REQUEST, null, MigrationManager.MigrationsSerializer.instance); + Message message = Message.out(SCHEMA_PULL_REQ, noPayload); final CountDownLatch completionLatch = new CountDownLatch(1); - IAsyncCallback> cb = new IAsyncCallback>() + RequestCallback> cb = msg -> { - @Override - public void response(MessageIn> message) + try { - try - { - Schema.instance.mergeAndAnnounceVersion(message.payload); - } - catch (ConfigurationException e) - { - logger.error("Configuration exception merging remote schema", e); - } - finally - { - completionLatch.countDown(); - } + Schema.instance.mergeAndAnnounceVersion(msg.payload); } - - public 
boolean isLatencyForSnitch() + catch (ConfigurationException e) + { + logger.error("Configuration exception merging remote schema", e); + } + finally { - return false; + completionLatch.countDown(); } }; @@ -111,7 +104,7 @@ public boolean isLatencyForSnitch() if (monitoringBootstrapStates.contains(SystemKeyspace.getBootstrapState())) inflightTasks.offer(completionLatch); - MessagingService.instance().sendRR(message, endpoint, cb); + MessagingService.instance().sendWithCallback(message, endpoint, cb); SchemaMigrationDiagnostics.taskRequestSend(endpoint); } diff --git a/src/java/org/apache/cassandra/schema/SchemaKeyspace.java b/src/java/org/apache/cassandra/schema/SchemaKeyspace.java index d2abc03ce366..5498d6eafd7b 100644 --- a/src/java/org/apache/cassandra/schema/SchemaKeyspace.java +++ b/src/java/org/apache/cassandra/schema/SchemaKeyspace.java @@ -164,8 +164,8 @@ private SchemaKeyspace() + "table_name text," + "column_name text," + "dropped_time timestamp," - + "type text," + "kind text," + + "type text," + "PRIMARY KEY ((keyspace_name), table_name, column_name))"); private static final TableMetadata Triggers = @@ -345,7 +345,7 @@ public static void truncate() private static void flush() { if (!DatabaseDescriptor.isUnsafeSystem()) - ALL.forEach(table -> FBUtilities.waitOnFuture(getSchemaCFS(table).forceFlush())); + ALL.forEach(table -> FBUtilities.waitOnFuture(getSchemaCFS(table).forceFlushToSSTable())); } /** diff --git a/src/java/org/apache/cassandra/schema/SchemaMigrationEvent.java b/src/java/org/apache/cassandra/schema/SchemaMigrationEvent.java index 2c1723589808..45844b3403b9 100644 --- a/src/java/org/apache/cassandra/schema/SchemaMigrationEvent.java +++ b/src/java/org/apache/cassandra/schema/SchemaMigrationEvent.java @@ -85,8 +85,8 @@ enum MigrationManagerEventType if (endpoint == null) return; - if (MessagingService.instance().knowsVersion(endpoint)) - endpointMessagingVersion = MessagingService.instance().getRawVersion(endpoint); + if 
(MessagingService.instance().versions.knows(endpoint)) + endpointMessagingVersion = MessagingService.instance().versions.getRaw(endpoint); endpointGossipOnlyMember = Gossiper.instance.isGossipOnlyMember(endpoint); this.isAlive = FailureDetector.instance.isAlive(endpoint); diff --git a/src/java/org/apache/cassandra/schema/SchemaPullVerbHandler.java b/src/java/org/apache/cassandra/schema/SchemaPullVerbHandler.java index 45cf365d8505..ed30792fd490 100644 --- a/src/java/org/apache/cassandra/schema/SchemaPullVerbHandler.java +++ b/src/java/org/apache/cassandra/schema/SchemaPullVerbHandler.java @@ -24,27 +24,24 @@ import org.apache.cassandra.db.Mutation; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.NoPayload; /** - * Sends it's current schema state in form of mutations in reply to the remote node's request. + * Sends it's current schema state in form of mutations in response to the remote node's request. * Such a request is made when one of the nodes, by means of Gossip, detects schema disagreement in the ring. 
*/ -public final class SchemaPullVerbHandler implements IVerbHandler +public final class SchemaPullVerbHandler implements IVerbHandler { + public static final SchemaPullVerbHandler instance = new SchemaPullVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(SchemaPullVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - logger.trace("Received schema pull request from {}", message.from); - - MessageOut> response = - new MessageOut<>(MessagingService.Verb.INTERNAL_RESPONSE, - SchemaKeyspace.convertSchemaToMutations(), - MigrationManager.MigrationsSerializer.instance); - - MessagingService.instance().sendReply(response, id, message.from); + logger.trace("Received schema pull request from {}", message.from()); + Message> response = message.responseWith(SchemaKeyspace.convertSchemaToMutations()); + MessagingService.instance().send(response, message.from()); } } \ No newline at end of file diff --git a/src/java/org/apache/cassandra/schema/SchemaPushVerbHandler.java b/src/java/org/apache/cassandra/schema/SchemaPushVerbHandler.java index 358739a1b57e..8d1bb0ff0760 100644 --- a/src/java/org/apache/cassandra/schema/SchemaPushVerbHandler.java +++ b/src/java/org/apache/cassandra/schema/SchemaPushVerbHandler.java @@ -26,7 +26,7 @@ import org.apache.cassandra.concurrent.StageManager; import org.apache.cassandra.db.Mutation; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; /** * Called when node receives updated schema state from the schema migration coordinator node. 
@@ -36,13 +36,15 @@ */ public final class SchemaPushVerbHandler implements IVerbHandler> { + public static final SchemaPushVerbHandler instance = new SchemaPushVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(SchemaPushVerbHandler.class); - public void doVerb(final MessageIn> message, int id) + public void doVerb(final Message> message) { - logger.trace("Received schema push request from {}", message.from); + logger.trace("Received schema push request from {}", message.from()); - SchemaAnnouncementDiagnostics.schemataMutationsReceived(message.from); + SchemaAnnouncementDiagnostics.schemataMutationsReceived(message.from()); StageManager.getStage(Stage.MIGRATION).submit(() -> Schema.instance.mergeAndAnnounceVersion(message.payload)); } } \ No newline at end of file diff --git a/src/java/org/apache/cassandra/schema/SchemaVersionVerbHandler.java b/src/java/org/apache/cassandra/schema/SchemaVersionVerbHandler.java index 0a506e3cdda6..80090de5576a 100644 --- a/src/java/org/apache/cassandra/schema/SchemaVersionVerbHandler.java +++ b/src/java/org/apache/cassandra/schema/SchemaVersionVerbHandler.java @@ -23,24 +23,20 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.UUIDSerializer; +import org.apache.cassandra.net.NoPayload; -public final class SchemaVersionVerbHandler implements IVerbHandler +public final class SchemaVersionVerbHandler implements IVerbHandler { + public static final SchemaVersionVerbHandler instance = new SchemaVersionVerbHandler(); + private final Logger logger = LoggerFactory.getLogger(SchemaVersionVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - logger.trace("Received schema version request from {}", message.from); - - 
MessageOut response = - new MessageOut<>(MessagingService.Verb.INTERNAL_RESPONSE, - Schema.instance.getVersion(), - UUIDSerializer.serializer); - - MessagingService.instance().sendReply(response, id, message.from); + logger.trace("Received schema version request from {}", message.from()); + Message response = message.responseWith(Schema.instance.getVersion()); + MessagingService.instance().send(response, message.from()); } } diff --git a/src/java/org/apache/cassandra/security/SSLFactory.java b/src/java/org/apache/cassandra/security/SSLFactory.java index 75198759f1f3..f6bbcd07af9a 100644 --- a/src/java/org/apache/cassandra/security/SSLFactory.java +++ b/src/java/org/apache/cassandra/security/SSLFactory.java @@ -56,6 +56,7 @@ import io.netty.handler.ssl.SupportedCipherSuiteFilter; import io.netty.util.ReferenceCountUtil; import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.Config; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.config.EncryptionOptions; @@ -83,6 +84,26 @@ public enum SocketType @VisibleForTesting static volatile boolean checkedExpiry = false; + // Isolate calls to OpenSsl.isAvailable to allow in-jvm dtests to disable tcnative openssl + // support. It creates a circular reference that prevents the instance class loader from being + // garbage collected. 
+ static private final boolean openSslIsAvailable; + static + { + if (Boolean.getBoolean(Config.PROPERTY_PREFIX + "disable_tcactive_openssl")) + { + openSslIsAvailable = false; + } + else + { + openSslIsAvailable = OpenSsl.isAvailable(); + } + } + public static boolean openSslIsAvailable() + { + return openSslIsAvailable; + } + /** * Cached references of SSL Contexts */ @@ -233,17 +254,19 @@ public static String[] filterCipherSuites(String[] supported, String[] desired) public static SslContext getOrCreateSslContext(EncryptionOptions options, boolean buildTruststore, SocketType socketType) throws IOException { - return getOrCreateSslContext(options, buildTruststore, socketType, OpenSsl.isAvailable()); + return getOrCreateSslContext(options, buildTruststore, socketType, openSslIsAvailable()); } /** * Get a netty {@link SslContext} instance. */ @VisibleForTesting - static SslContext getOrCreateSslContext(EncryptionOptions options, boolean buildTruststore, - SocketType socketType, boolean useOpenSsl) throws IOException + static SslContext getOrCreateSslContext(EncryptionOptions options, + boolean buildTruststore, + SocketType socketType, + boolean useOpenSsl) throws IOException { - CacheKey key = new CacheKey(options, socketType); + CacheKey key = new CacheKey(options, socketType, useOpenSsl); SslContext sslContext; sslContext = cachedSslContexts.get(key); @@ -290,8 +313,8 @@ static SslContext createNettySslContext(EncryptionOptions options, boolean build // only set the cipher suites if the opertor has explicity configured values for it; else, use the default // for each ssl implemention (jdk or openssl) - if (options.cipher_suites != null && options.cipher_suites.length > 0) - builder.ciphers(Arrays.asList(options.cipher_suites), SupportedCipherSuiteFilter.INSTANCE); + if (options.cipher_suites != null && !options.cipher_suites.isEmpty()) + builder.ciphers(options.cipher_suites, SupportedCipherSuiteFilter.INSTANCE); if (buildTruststore) 
builder.trustManager(buildTrustManagerFactory(options)); @@ -385,8 +408,8 @@ public static void validateSslCerts(EncryptionOptions.ServerEncryptionOptions se // Ensure we're able to create both server & client SslContexts if (serverOpts != null && serverOpts.enabled) { - createNettySslContext(serverOpts, true, SocketType.SERVER, OpenSsl.isAvailable()); - createNettySslContext(serverOpts, true, SocketType.CLIENT, OpenSsl.isAvailable()); + createNettySslContext(serverOpts, true, SocketType.SERVER, openSslIsAvailable()); + createNettySslContext(serverOpts, true, SocketType.CLIENT, openSslIsAvailable()); } } catch (Exception e) @@ -399,8 +422,8 @@ public static void validateSslCerts(EncryptionOptions.ServerEncryptionOptions se // Ensure we're able to create both server & client SslContexts if (clientOpts != null && clientOpts.enabled) { - createNettySslContext(clientOpts, clientOpts.require_client_auth, SocketType.SERVER, OpenSsl.isAvailable()); - createNettySslContext(clientOpts, clientOpts.require_client_auth, SocketType.CLIENT, OpenSsl.isAvailable()); + createNettySslContext(clientOpts, clientOpts.require_client_auth, SocketType.SERVER, openSslIsAvailable()); + createNettySslContext(clientOpts, clientOpts.require_client_auth, SocketType.CLIENT, openSslIsAvailable()); } } catch (Exception e) @@ -413,11 +436,13 @@ static class CacheKey { private final EncryptionOptions encryptionOptions; private final SocketType socketType; + private final boolean useOpenSSL; - public CacheKey(EncryptionOptions encryptionOptions, SocketType socketType) + public CacheKey(EncryptionOptions encryptionOptions, SocketType socketType, boolean useOpenSSL) { this.encryptionOptions = encryptionOptions; this.socketType = socketType; + this.useOpenSSL = useOpenSSL; } public boolean equals(Object o) @@ -426,6 +451,7 @@ public boolean equals(Object o) if (o == null || getClass() != o.getClass()) return false; CacheKey cacheKey = (CacheKey) o; return (socketType == cacheKey.socketType && + 
useOpenSSL == cacheKey.useOpenSSL && Objects.equals(encryptionOptions, cacheKey.encryptionOptions)); } @@ -434,6 +460,7 @@ public int hashCode() int result = 0; result += 31 * socketType.hashCode(); result += 31 * encryptionOptions.hashCode(); + result += 31 * Boolean.hashCode(useOpenSSL); return result; } } diff --git a/src/java/org/apache/cassandra/serializers/CollectionSerializer.java b/src/java/org/apache/cassandra/serializers/CollectionSerializer.java index d988cc0f131c..3efdef987c13 100644 --- a/src/java/org/apache/cassandra/serializers/CollectionSerializer.java +++ b/src/java/org/apache/cassandra/serializers/CollectionSerializer.java @@ -166,7 +166,7 @@ protected ByteBuffer copyAsNewCollection(ByteBuffer input, int count, int startP ByteBuffer output = ByteBuffer.allocate(sizeLen + bodyLen); writeCollectionSize(output, count, version); output.position(0); - ByteBufferUtil.arrayCopy(input, startPos, output, sizeLen, bodyLen); + ByteBufferUtil.copyBytes(input, startPos, output, sizeLen, bodyLen); return output; } } diff --git a/src/java/org/apache/cassandra/service/AbstractWriteResponseHandler.java b/src/java/org/apache/cassandra/service/AbstractWriteResponseHandler.java index 1470cadc84bf..1889c79c28a4 100644 --- a/src/java/org/apache/cassandra/service/AbstractWriteResponseHandler.java +++ b/src/java/org/apache/cassandra/service/AbstractWriteResponseHandler.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.stream.Collectors; @@ -40,13 +39,15 @@ import org.apache.cassandra.exceptions.WriteFailureException; import org.apache.cassandra.exceptions.WriteTimeoutException; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IAsyncCallbackWithFailure; -import org.apache.cassandra.net.MessageIn; +import 
org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.utils.concurrent.SimpleCondition; +import static java.util.concurrent.TimeUnit.NANOSECONDS; -public abstract class AbstractWriteResponseHandler implements IAsyncCallbackWithFailure + +public abstract class AbstractWriteResponseHandler implements RequestCallback { protected static final Logger logger = LoggerFactory.getLogger(AbstractWriteResponseHandler.class); @@ -90,12 +91,12 @@ protected AbstractWriteResponseHandler(ReplicaPlan.ForTokenWrite replicaPlan, public void get() throws WriteTimeoutException, WriteFailureException { - long timeout = currentTimeout(); + long timeoutNanos = currentTimeoutNanos(); boolean success; try { - success = condition.await(timeout, TimeUnit.NANOSECONDS); + success = condition.await(timeoutNanos, NANOSECONDS); } catch (InterruptedException ex) { @@ -120,12 +121,12 @@ public void get() throws WriteTimeoutException, WriteFailureException } } - public final long currentTimeout() + public final long currentTimeoutNanos() { long requestTimeout = writeType == WriteType.COUNTER - ? DatabaseDescriptor.getCounterWriteRpcTimeout() - : DatabaseDescriptor.getWriteRpcTimeout(); - return TimeUnit.MILLISECONDS.toNanos(requestTimeout) - (System.nanoTime() - queryStartNanoTime); + ? DatabaseDescriptor.getCounterWriteRpcTimeout(NANOSECONDS) + : DatabaseDescriptor.getWriteRpcTimeout(NANOSECONDS); + return requestTimeout - (System.nanoTime() - queryStartNanoTime); } /** @@ -143,7 +144,7 @@ public void setIdealCLResponseHandler(AbstractWriteResponseHandler handler) * on whether the CL was achieved. Only call this after the subclass has completed all it's processing * since the subclass instance may be queried to find out if the CL was achieved. 
*/ - protected final void logResponseToIdealCLDelegate(MessageIn m) + protected final void logResponseToIdealCLDelegate(Message m) { //Tracking ideal CL was not configured if (idealCLDelegate == null) @@ -162,7 +163,7 @@ protected final void logResponseToIdealCLDelegate(MessageIn m) //Let the delegate do full processing, this will loop back into the branch above //with idealCLDelegate == this, because the ideal write handler idealCLDelegate will always //be set to this in the delegate. - idealCLDelegate.response(m); + idealCLDelegate.onResponse(m); } } @@ -187,7 +188,7 @@ public final void expired() } /** - * @return the minimum number of endpoints that must reply. + * @return the minimum number of endpoints that must respond. */ protected int blockFor() { @@ -227,7 +228,7 @@ protected boolean waitingFor(InetAddressAndPort from) /** * null message means "response from local write" */ - public abstract void response(MessageIn msg); + public abstract void onResponse(Message msg); protected void signal() { @@ -251,6 +252,12 @@ public void onFailure(InetAddressAndPort from, RequestFailureReason failureReaso signal(); } + @Override + public boolean invokeOnFailure() + { + return true; + } + @Override public boolean supportsBackPressure() { @@ -301,12 +308,12 @@ public void maybeTryAdditionalReplicas(IMutation mutation, StorageProxy.WritePer timeout = Math.min(timeout, cf.additionalWriteLatencyNanos); // no latency information, or we're overloaded - if (timeout > TimeUnit.MILLISECONDS.toNanos(mutation.getTimeout())) + if (timeout > mutation.getTimeout(NANOSECONDS)) return; try { - if (!condition.await(timeout, TimeUnit.NANOSECONDS)) + if (!condition.await(timeout, NANOSECONDS)) { for (ColumnFamilyStore cf : cfs) cf.metric.additionalWrites.inc(); diff --git a/src/java/org/apache/cassandra/service/ActiveRepairService.java b/src/java/org/apache/cassandra/service/ActiveRepairService.java index 525bebaf9627..6f4c474fba5e 100644 --- 
a/src/java/org/apache/cassandra/service/ActiveRepairService.java +++ b/src/java/org/apache/cassandra/service/ActiveRepairService.java @@ -58,9 +58,8 @@ import org.apache.cassandra.gms.VersionedValue; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.TokenMetadata; -import org.apache.cassandra.net.IAsyncCallbackWithFailure; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.repair.CommonRange; import org.apache.cassandra.streaming.PreviewKind; @@ -72,7 +71,6 @@ import org.apache.cassandra.repair.messages.*; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.utils.CassandraVersion; -import org.apache.cassandra.utils.Clock; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.MBeanWrapper; import org.apache.cassandra.utils.Pair; @@ -80,6 +78,7 @@ import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.transform; +import static org.apache.cassandra.net.Verb.PREPARE_MSG; /** * ActiveRepairService is the starting point for manual "active" repairs. 
@@ -113,8 +112,6 @@ public static class ConsistentSessions private boolean registeredForEndpointChanges = false; - public static CassandraVersion SUPPORTS_GLOBAL_PREPARE_FLAG_VERSION = new CassandraVersion("2.2.1"); - private static final Logger logger = LoggerFactory.getLogger(ActiveRepairService.class); // singleton enforcement public static final ActiveRepairService instance = new ActiveRepairService(FailureDetector.instance, Gossiper.instance); @@ -248,6 +245,16 @@ public void run() return session; } + public boolean getUseOffheapMerkleTrees() + { + return DatabaseDescriptor.useOffheapMerkleTrees(); + } + + public void setUseOffheapMerkleTrees(boolean value) + { + DatabaseDescriptor.useOffheapMerkleTrees(value); + } + private void registerOnFdAndGossip(final T task) @@ -381,7 +388,7 @@ static long getRepairedAt(RepairOption options, boolean force) // end up skipping replicas if (options.isIncremental() && options.isGlobal() && ! force) { - return Clock.instance.currentTimeMillis(); + return System.currentTimeMillis(); } else { @@ -396,24 +403,27 @@ public UUID prepareForRepair(UUID parentRepairSession, InetAddressAndPort coordi final CountDownLatch prepareLatch = new CountDownLatch(endpoints.size()); final AtomicBoolean status = new AtomicBoolean(true); final Set failedNodes = Collections.synchronizedSet(new HashSet()); - IAsyncCallbackWithFailure callback = new IAsyncCallbackWithFailure() + RequestCallback callback = new RequestCallback() { - public void response(MessageIn msg) + @Override + public void onResponse(Message msg) { prepareLatch.countDown(); } - public boolean isLatencyForSnitch() - { - return false; - } - + @Override public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) { status.set(false); failedNodes.add(from.toString()); prepareLatch.countDown(); } + + @Override + public boolean invokeOnFailure() + { + return true; + } }; List tableIds = new ArrayList<>(columnFamilyStores.size()); @@ -425,8 +435,8 @@ public void 
onFailure(InetAddressAndPort from, RequestFailureReason failureReaso if (FailureDetector.instance.isAlive(neighbour)) { PrepareMessage message = new PrepareMessage(parentRepairSession, tableIds, options.getRanges(), options.isIncremental(), repairedAt, options.isGlobal(), options.getPreviewKind()); - MessageOut msg = message.createMessage(); - MessagingService.instance().sendRR(msg, neighbour, callback, DatabaseDescriptor.getRpcTimeout(), true); + Message msg = Message.out(PREPARE_MSG, message); + MessagingService.instance().sendWithCallback(msg, neighbour, callback); } else { @@ -515,21 +525,21 @@ public synchronized ParentRepairSession removeParentRepairSession(UUID parentSes return parentRepairSessions.remove(parentSessionId); } - public void handleMessage(InetAddressAndPort endpoint, RepairMessage message) + public void handleMessage(Message message) { - RepairJobDesc desc = message.desc; + RepairJobDesc desc = message.payload.desc; RepairSession session = sessions.get(desc.sessionId); if (session == null) return; - switch (message.messageType) + switch (message.verb()) { - case VALIDATION_COMPLETE: - ValidationComplete validation = (ValidationComplete) message; - session.validationComplete(desc, endpoint, validation.trees); + case VALIDATION_RSP: + ValidationResponse validation = (ValidationResponse) message.payload; + session.validationComplete(desc, message.from(), validation.trees); break; - case SYNC_COMPLETE: + case SYNC_RSP: // one of replica is synced. 
- SyncComplete sync = (SyncComplete) message; + SyncResponse sync = (SyncResponse) message.payload; session.syncComplete(desc, sync.nodes, sync.success, sync.summaries); break; default: diff --git a/src/java/org/apache/cassandra/service/ActiveRepairServiceMBean.java b/src/java/org/apache/cassandra/service/ActiveRepairServiceMBean.java index 283d466b9c7d..d967280e83bc 100644 --- a/src/java/org/apache/cassandra/service/ActiveRepairServiceMBean.java +++ b/src/java/org/apache/cassandra/service/ActiveRepairServiceMBean.java @@ -30,4 +30,7 @@ public interface ActiveRepairServiceMBean public void setRepairSessionSpaceInMegabytes(int sizeInMegabytes); public int getRepairSessionSpaceInMegabytes(); + + public boolean getUseOffheapMerkleTrees(); + public void setUseOffheapMerkleTrees(boolean value); } diff --git a/src/java/org/apache/cassandra/service/BatchlogResponseHandler.java b/src/java/org/apache/cassandra/service/BatchlogResponseHandler.java index 63fbc729f345..b28f468af7e8 100644 --- a/src/java/org/apache/cassandra/service/BatchlogResponseHandler.java +++ b/src/java/org/apache/cassandra/service/BatchlogResponseHandler.java @@ -24,7 +24,7 @@ import org.apache.cassandra.exceptions.WriteFailureException; import org.apache.cassandra.exceptions.WriteTimeoutException; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; public class BatchlogResponseHandler extends AbstractWriteResponseHandler { @@ -47,21 +47,21 @@ protected int ackCount() return wrapped.ackCount(); } - public void response(MessageIn msg) + public void onResponse(Message msg) { - wrapped.response(msg); + wrapped.onResponse(msg); if (requiredBeforeFinishUpdater.decrementAndGet(this) == 0) cleanup.ackMutation(); } - public boolean isLatencyForSnitch() + public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) { - return wrapped.isLatencyForSnitch(); + wrapped.onFailure(from, failureReason); } - 
public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) + public boolean invokeOnFailure() { - wrapped.onFailure(from, failureReason); + return wrapped.invokeOnFailure(); } public void get() throws WriteTimeoutException, WriteFailureException diff --git a/src/java/org/apache/cassandra/service/CassandraDaemon.java b/src/java/org/apache/cassandra/service/CassandraDaemon.java index b8f06f6ad209..9d05371c89f3 100644 --- a/src/java/org/apache/cassandra/service/CassandraDaemon.java +++ b/src/java/org/apache/cassandra/service/CassandraDaemon.java @@ -71,6 +71,8 @@ import org.apache.cassandra.utils.*; import org.apache.cassandra.security.ThreadAwareSecurityManager; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + /** * The CassandraDaemon is an abstraction for a Cassandra daemon * service, which defines not only a way to activate and deactivate it, but also @@ -431,9 +433,9 @@ public void uncaughtException(Thread t, Throwable e) // schedule periodic recomputation of speculative retry thresholds ScheduledExecutors.optionalTasks.scheduleWithFixedDelay( () -> Keyspace.all().forEach(k -> k.getColumnFamilyStores().forEach(ColumnFamilyStore::updateSpeculationThreshold)), - DatabaseDescriptor.getReadRpcTimeout(), - DatabaseDescriptor.getReadRpcTimeout(), - TimeUnit.MILLISECONDS + DatabaseDescriptor.getReadRpcTimeout(NANOSECONDS), + DatabaseDescriptor.getReadRpcTimeout(NANOSECONDS), + NANOSECONDS ); // Native transport diff --git a/src/java/org/apache/cassandra/service/ClientState.java b/src/java/org/apache/cassandra/service/ClientState.java index 26ed271093d7..81574e64df17 100644 --- a/src/java/org/apache/cassandra/service/ClientState.java +++ b/src/java/org/apache/cassandra/service/ClientState.java @@ -31,6 +31,8 @@ import org.apache.cassandra.auth.*; import org.apache.cassandra.db.virtual.VirtualSchemaKeyspace; +import org.apache.cassandra.exceptions.RequestExecutionException; +import 
org.apache.cassandra.exceptions.RequestValidationException; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.schema.TableMetadataRef; import org.apache.cassandra.config.DatabaseDescriptor; @@ -323,12 +325,24 @@ public void setKeyspace(String ks) */ public void login(AuthenticatedUser user) { - if (user.isAnonymous() || user.canLogin()) + if (user.isAnonymous() || canLogin(user)) this.user = user; else throw new AuthenticationException(String.format("%s is not permitted to log in", user.getName())); } + private boolean canLogin(AuthenticatedUser user) + { + try + { + return user.canLogin(); + } + catch (RequestExecutionException | RequestValidationException e) + { + throw new AuthenticationException("Unable to perform authentication: " + e.getMessage(), e); + } + } + public void ensureAllKeyspacesPermission(Permission perm) { if (isInternal) diff --git a/src/java/org/apache/cassandra/service/DatacenterSyncWriteResponseHandler.java b/src/java/org/apache/cassandra/service/DatacenterSyncWriteResponseHandler.java index 4c892ffad993..1f536c7db3b8 100644 --- a/src/java/org/apache/cassandra/service/DatacenterSyncWriteResponseHandler.java +++ b/src/java/org/apache/cassandra/service/DatacenterSyncWriteResponseHandler.java @@ -26,7 +26,7 @@ import org.apache.cassandra.locator.NetworkTopologyStrategy; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.ReplicaPlan; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.db.WriteType; @@ -65,13 +65,13 @@ public DatacenterSyncWriteResponseHandler(ReplicaPlan.ForTokenWrite replicaPlan, } } - public void response(MessageIn message) + public void onResponse(Message message) { try { String dataCenter = message == null ? 
DatabaseDescriptor.getLocalDataCenter() - : snitch.getDatacenter(message.from); + : snitch.getDatacenter(message.from()); responses.get(dataCenter).getAndDecrement(); acks.incrementAndGet(); @@ -96,10 +96,4 @@ protected int ackCount() { return acks.get(); } - - public boolean isLatencyForSnitch() - { - return false; - } - } diff --git a/src/java/org/apache/cassandra/service/DatacenterWriteResponseHandler.java b/src/java/org/apache/cassandra/service/DatacenterWriteResponseHandler.java index f30b4525fffc..a9583a3c3512 100644 --- a/src/java/org/apache/cassandra/service/DatacenterWriteResponseHandler.java +++ b/src/java/org/apache/cassandra/service/DatacenterWriteResponseHandler.java @@ -21,7 +21,7 @@ import org.apache.cassandra.locator.InOurDcTester; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.ReplicaPlan; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import java.util.function.Predicate; @@ -42,11 +42,11 @@ public DatacenterWriteResponseHandler(ReplicaPlan.ForTokenWrite replicaPlan, } @Override - public void response(MessageIn message) + public void onResponse(Message message) { - if (message == null || waitingFor(message.from)) + if (message == null || waitingFor(message.from())) { - super.response(message); + super.onResponse(message); } else { diff --git a/src/java/org/apache/cassandra/service/EchoVerbHandler.java b/src/java/org/apache/cassandra/service/EchoVerbHandler.java index 1cc52e9cc6b5..76f23d46d816 100644 --- a/src/java/org/apache/cassandra/service/EchoVerbHandler.java +++ b/src/java/org/apache/cassandra/service/EchoVerbHandler.java @@ -19,28 +19,23 @@ * under the License. 
* */ - - -import org.apache.cassandra.gms.EchoMessage; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.NoPayload; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; - -public class EchoVerbHandler implements IVerbHandler +public class EchoVerbHandler implements IVerbHandler { + public static final EchoVerbHandler instance = new EchoVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(EchoVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - MessageOut echoMessage = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE, EchoMessage.instance, - EchoMessage.serializer, ConnectionType.GOSSIP); - logger.trace("Sending a EchoMessage reply {}", message.from); - MessagingService.instance().sendReply(echoMessage, id, message.from); + logger.trace("Sending ECHO_RSP to {}", message.from()); + MessagingService.instance().send(message.emptyResponse(), message.from()); } } diff --git a/src/java/org/apache/cassandra/service/GCInspector.java b/src/java/org/apache/cassandra/service/GCInspector.java index 657d3adeacb8..e0a935dcade9 100644 --- a/src/java/org/apache/cassandra/service/GCInspector.java +++ b/src/java/org/apache/cassandra/service/GCInspector.java @@ -69,7 +69,16 @@ public class GCInspector implements NotificationListener, GCInspectorMXBean try { Class bitsClass = Class.forName("java.nio.Bits"); - Field f = bitsClass.getDeclaredField("totalCapacity"); + Field f; + try + { + f = bitsClass.getDeclaredField("totalCapacity"); + } + catch (NoSuchFieldException ex) + { + // in Java11 it changed name to "TOTAL_CAPACITY" + f = bitsClass.getDeclaredField("TOTAL_CAPACITY"); + } 
f.setAccessible(true); temp = f; } diff --git a/src/java/org/apache/cassandra/service/NativeTransportService.java b/src/java/org/apache/cassandra/service/NativeTransportService.java index 79acab1b02a0..66b50007be35 100644 --- a/src/java/org/apache/cassandra/service/NativeTransportService.java +++ b/src/java/org/apache/cassandra/service/NativeTransportService.java @@ -32,10 +32,9 @@ import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.util.concurrent.EventExecutor; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.metrics.ClientMetrics; -import org.apache.cassandra.transport.RequestThreadPoolExecutor; +import org.apache.cassandra.transport.Message; import org.apache.cassandra.transport.Server; import org.apache.cassandra.utils.NativeLibrary; @@ -51,7 +50,6 @@ public class NativeTransportService private boolean initialized = false; private EventLoopGroup workerGroup; - private EventExecutor eventExecutorGroup; /** * Creates netty thread pools and event loops. 
@@ -62,9 +60,6 @@ synchronized void initialize() if (initialized) return; - // prepare netty resources - eventExecutorGroup = new RequestThreadPoolExecutor(); - if (useEpoll()) { workerGroup = new EpollEventLoopGroup(); @@ -81,7 +76,6 @@ synchronized void initialize() InetAddress nativeAddr = DatabaseDescriptor.getRpcAddress(); org.apache.cassandra.transport.Server.Builder builder = new org.apache.cassandra.transport.Server.Builder() - .withEventExecutor(eventExecutorGroup) .withEventLoopGroup(workerGroup) .withHost(nativeAddr); @@ -141,8 +135,7 @@ public void destroy() // shutdown executors used by netty for native transport server workerGroup.shutdownGracefully(3, 5, TimeUnit.SECONDS).awaitUninterruptibly(); - // shutdownGracefully not implemented yet in RequestThreadPoolExecutor - eventExecutorGroup.shutdown(); + Message.Dispatcher.shutdown(); } /** @@ -153,7 +146,7 @@ public static boolean useEpoll() final boolean enableEpoll = Boolean.parseBoolean(System.getProperty("cassandra.native.epoll.enabled", "true")); if (enableEpoll && !Epoll.isAvailable() && NativeLibrary.osType == NativeLibrary.OSType.LINUX) - logger.warn("epoll not available {}", Epoll.unavailabilityCause()); + logger.warn("epoll not available", Epoll.unavailabilityCause()); return enableEpoll && Epoll.isAvailable(); } @@ -174,12 +167,6 @@ EventLoopGroup getWorkerGroup() return workerGroup; } - @VisibleForTesting - EventExecutor getEventExecutor() - { - return eventExecutorGroup; - } - @VisibleForTesting Collection getServers() { diff --git a/src/java/org/apache/cassandra/service/PendingRangeCalculatorService.java b/src/java/org/apache/cassandra/service/PendingRangeCalculatorService.java index a3f6b5232372..1c6b18347ea4 100644 --- a/src/java/org/apache/cassandra/service/PendingRangeCalculatorService.java +++ b/src/java/org/apache/cassandra/service/PendingRangeCalculatorService.java @@ -23,6 +23,8 @@ import org.apache.cassandra.schema.Schema; import org.apache.cassandra.db.Keyspace; import 
org.apache.cassandra.locator.AbstractReplicationStrategy; +import org.apache.cassandra.utils.ExecutorUtils; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,9 +123,8 @@ public static void calculatePendingRanges(AbstractReplicationStrategy strategy, } @VisibleForTesting - public void shutdownExecutor() throws InterruptedException + public void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - executor.shutdown(); - executor.awaitTermination(60, TimeUnit.SECONDS); + ExecutorUtils.shutdownNowAndWait(timeout, unit, executor); } } diff --git a/src/java/org/apache/cassandra/service/SnapshotVerbHandler.java b/src/java/org/apache/cassandra/service/SnapshotVerbHandler.java index a9975338b780..cf2872b4b51e 100644 --- a/src/java/org/apache/cassandra/service/SnapshotVerbHandler.java +++ b/src/java/org/apache/cassandra/service/SnapshotVerbHandler.java @@ -23,24 +23,24 @@ import org.apache.cassandra.db.SnapshotCommand; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; public class SnapshotVerbHandler implements IVerbHandler { + public static final SnapshotVerbHandler instance = new SnapshotVerbHandler(); + private static final Logger logger = LoggerFactory.getLogger(SnapshotVerbHandler.class); - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { SnapshotCommand command = message.payload; if (command.clear_snapshot) - { Keyspace.clearSnapshot(command.snapshot_name, command.keyspace); - } else Keyspace.open(command.keyspace).getColumnFamilyStore(command.column_family).snapshot(command.snapshot_name); - logger.debug("Enqueuing response to snapshot request {} to {}", command.snapshot_name, message.from); - MessagingService.instance().sendReply(new 
MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), id, message.from); + + logger.debug("Enqueuing response to snapshot request {} to {}", command.snapshot_name, message.from()); + MessagingService.instance().send(message.emptyResponse(), message.from()); } } diff --git a/src/java/org/apache/cassandra/service/StorageProxy.java b/src/java/org/apache/cassandra/service/StorageProxy.java index ce7867429b7a..d2dd956607b2 100644 --- a/src/java/org/apache/cassandra/service/StorageProxy.java +++ b/src/java/org/apache/cassandra/service/StorageProxy.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.service; -import java.io.File; import java.nio.ByteBuffer; import java.nio.file.Paths; import java.util.*; @@ -26,7 +25,6 @@ import java.util.concurrent.atomic.AtomicLong; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import com.google.common.cache.CacheLoader; import com.google.common.collect.*; import com.google.common.primitives.Ints; @@ -69,17 +67,25 @@ import org.apache.cassandra.service.paxos.Commit; import org.apache.cassandra.service.paxos.PaxosState; import org.apache.cassandra.service.paxos.PrepareCallback; -import org.apache.cassandra.service.paxos.PrepareResponse; -import org.apache.cassandra.service.paxos.PrepareVerbHandler; import org.apache.cassandra.service.paxos.ProposeCallback; -import org.apache.cassandra.service.paxos.ProposeVerbHandler; -import org.apache.cassandra.net.MessagingService.Verb; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.triggers.TriggerExecutor; import org.apache.cassandra.utils.*; import org.apache.cassandra.utils.AbstractIterator; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.Verb.BATCH_STORE_REQ; +import static org.apache.cassandra.net.Verb.MUTATION_REQ; +import 
static org.apache.cassandra.net.Verb.PAXOS_COMMIT_REQ; +import static org.apache.cassandra.net.Verb.PAXOS_PREPARE_REQ; +import static org.apache.cassandra.net.Verb.PAXOS_PROPOSE_REQ; +import static org.apache.cassandra.net.Verb.TRUNCATE_REQ; import static org.apache.cassandra.service.BatchlogResponseHandler.BatchlogCleanup; +import static org.apache.cassandra.service.paxos.PrepareVerbHandler.doPrepare; +import static org.apache.cassandra.service.paxos.ProposeVerbHandler.doPropose; public class StorageProxy implements StorageProxyMBean { @@ -173,7 +179,7 @@ private StorageProxy() * 1. Prepare: the coordinator generates a ballot (timeUUID in our case) and asks replicas to (a) promise * not to accept updates from older ballots and (b) tell us about the most recent update it has already * accepted. - * 2. Accept: if a majority of replicas reply, the coordinator asks replicas to accept the value of the + * 2. Accept: if a majority of replicas respond, the coordinator asks replicas to accept the value of the * highest proposal ballot it heard about, or a new value if no in-progress proposals were reported. * 3. Commit (Learn): if a majority of replicas acknowledge the accept request, we can commit the new * value. 
@@ -219,8 +225,8 @@ public static RowIterator cas(String keyspaceName, consistencyForPaxos.validateForCas(); consistencyForCommit.validateForCasCommit(keyspaceName); - long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getCasContentionTimeout()); - while (System.nanoTime() - queryStartNanoTime < timeout) + long timeoutNanos = DatabaseDescriptor.getCasContentionTimeout(NANOSECONDS); + while (System.nanoTime() - queryStartNanoTime < timeoutNanos) { // for simplicity, we'll do a single liveness check at the start of each attempt ReplicaPlan.ForPaxosWrite replicaPlan = ReplicaPlans.forPaxos(Keyspace.open(keyspaceName), key, consistencyForPaxos); @@ -276,7 +282,7 @@ public static RowIterator cas(String keyspaceName, Tracing.trace("Paxos proposal not accepted (pre-empted by a higher ballot)"); contentions++; - Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), MILLISECONDS); // continue to retry } @@ -332,11 +338,11 @@ private static PaxosBallotAndContention beginAndRepairPaxos(long queryStartNanoT ClientState state) throws WriteTimeoutException, WriteFailureException { - long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getCasContentionTimeout()); + long timeoutNanos = DatabaseDescriptor.getCasContentionTimeout(NANOSECONDS); PrepareCallback summary = null; int contentions = 0; - while (System.nanoTime() - queryStartNanoTime < timeout) + while (System.nanoTime() - queryStartNanoTime < timeoutNanos) { // We want a timestamp that is guaranteed to be unique for that node (so that the ballot is globally unique), but if we've got a prepare rejected // already we also want to make sure we pick a timestamp that has a chance to be promised, i.e. 
one that is greater that the most recently known @@ -357,7 +363,7 @@ private static PaxosBallotAndContention beginAndRepairPaxos(long queryStartNanoT Tracing.trace("Some replicas have already promised a higher ballot than ours; aborting"); contentions++; // sleep a random amount to give the other proposer a chance to finish - Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), MILLISECONDS); continue; } @@ -392,7 +398,7 @@ private static PaxosBallotAndContention beginAndRepairPaxos(long queryStartNanoT Tracing.trace("Some replicas have already promised a higher ballot than ours; aborting"); // sleep a random amount to give the other proposer a chance to finish contentions++; - Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(ThreadLocalRandom.current().nextInt(100), MILLISECONDS); } continue; } @@ -426,43 +432,34 @@ private static PaxosBallotAndContention beginAndRepairPaxos(long queryStartNanoT */ private static void sendCommit(Commit commit, Iterable replicas) { - MessageOut message = new MessageOut(MessagingService.Verb.PAXOS_COMMIT, commit, Commit.serializer); + Message message = Message.out(PAXOS_COMMIT_REQ, commit); for (InetAddressAndPort target : replicas) - MessagingService.instance().sendOneWay(message, target); + MessagingService.instance().send(message, target); } private static PrepareCallback preparePaxos(Commit toPrepare, ReplicaPlan.ForPaxosWrite replicaPlan, long queryStartNanoTime) throws WriteTimeoutException { PrepareCallback callback = new PrepareCallback(toPrepare.update.partitionKey(), toPrepare.update.metadata(), replicaPlan.requiredParticipants(), replicaPlan.consistencyLevel(), queryStartNanoTime); - MessageOut message = new MessageOut(MessagingService.Verb.PAXOS_PREPARE, toPrepare, Commit.serializer); + Message 
message = Message.out(PAXOS_PREPARE_REQ, toPrepare); for (Replica replica: replicaPlan.contacts()) { if (replica.isSelf()) { - StageManager.getStage(MessagingService.verbStages.get(MessagingService.Verb.PAXOS_PREPARE)).execute(new Runnable() - { - public void run() + StageManager.getStage(PAXOS_PREPARE_REQ.stage).execute(() -> { + try { - try - { - MessageIn message = MessageIn.create(FBUtilities.getBroadcastAddressAndPort(), - PrepareVerbHandler.doPrepare(toPrepare), - Collections.emptyMap(), - MessagingService.Verb.INTERNAL_RESPONSE, - MessagingService.current_version); - callback.response(message); - } - catch (Exception ex) - { - logger.error("Failed paxos prepare locally", ex); - } + callback.onResponse(message.responseWith(doPrepare(toPrepare))); + } + catch (Exception ex) + { + logger.error("Failed paxos prepare locally", ex); } }); } else { - MessagingService.instance().sendRR(message, replica.endpoint(), callback); + MessagingService.instance().sendWithCallback(message, replica.endpoint(), callback); } } callback.await(); @@ -473,34 +470,26 @@ private static boolean proposePaxos(Commit proposal, ReplicaPlan.ForPaxosWrite r throws WriteTimeoutException { ProposeCallback callback = new ProposeCallback(replicaPlan.contacts().size(), replicaPlan.requiredParticipants(), !timeoutIfPartial, replicaPlan.consistencyLevel(), queryStartNanoTime); - MessageOut message = new MessageOut(MessagingService.Verb.PAXOS_PROPOSE, proposal, Commit.serializer); + Message message = Message.out(PAXOS_PROPOSE_REQ, proposal); for (Replica replica : replicaPlan.contacts()) { if (replica.isSelf()) { - StageManager.getStage(MessagingService.verbStages.get(MessagingService.Verb.PAXOS_PROPOSE)).execute(new Runnable() - { - public void run() + StageManager.getStage(PAXOS_PROPOSE_REQ.stage).execute(() -> { + try { - try - { - MessageIn message = MessageIn.create(FBUtilities.getBroadcastAddressAndPort(), - ProposeVerbHandler.doPropose(proposal), - Collections.emptyMap(), - 
MessagingService.Verb.INTERNAL_RESPONSE, - MessagingService.current_version); - callback.response(message); - } - catch (Exception ex) - { - logger.error("Failed paxos propose locally", ex); - } + Message response = message.responseWith(doPropose(proposal)); + callback.onResponse(response); + } + catch (Exception ex) + { + logger.error("Failed paxos propose locally", ex); } }); } else { - MessagingService.instance().sendRR(message, replica.endpoint(), callback); + MessagingService.instance().sendWithCallback(message, replica.endpoint(), callback); } } callback.await(); @@ -531,7 +520,7 @@ private static void commitPaxos(Commit proposal, ConsistencyLevel consistencyLev responseHandler.setSupportsBackPressure(false); } - MessageOut message = new MessageOut<>(MessagingService.Verb.PAXOS_COMMIT, proposal, Commit.serializer); + Message message = Message.outWithFlag(PAXOS_COMMIT_REQ, proposal, MessageFlag.CALL_BACK_ON_FAILURE); for (Replica replica : replicaPlan.liveAndDown()) { InetAddressAndPort destination = replica.endpoint(); @@ -544,11 +533,11 @@ private static void commitPaxos(Commit proposal, ConsistencyLevel consistencyLev if (replica.isSelf()) commitPaxosLocal(replica, message, responseHandler); else - MessagingService.instance().sendWriteRR(message, replica, responseHandler, allowHints && shouldHint(replica)); + MessagingService.instance().sendWriteWithCallback(message, replica, responseHandler, allowHints && shouldHint(replica)); } else { - MessagingService.instance().sendOneWay(message, destination); + MessagingService.instance().send(message, destination); } } else @@ -573,9 +562,9 @@ private static void commitPaxos(Commit proposal, ConsistencyLevel consistencyLev * submit a fake one that executes immediately on the mutation stage, but generates the necessary backpressure * signal for hints */ - private static void commitPaxosLocal(Replica localReplica, final MessageOut message, final AbstractWriteResponseHandler responseHandler) + private static void 
commitPaxosLocal(Replica localReplica, final Message message, final AbstractWriteResponseHandler responseHandler) { - StageManager.getStage(MessagingService.verbStages.get(MessagingService.Verb.PAXOS_COMMIT)).maybeExecuteImmediately(new LocalMutationRunnable(localReplica) + StageManager.getStage(PAXOS_COMMIT_REQ.stage).maybeExecuteImmediately(new LocalMutationRunnable(localReplica) { public void runMayThrow() { @@ -583,20 +572,20 @@ public void runMayThrow() { PaxosState.commit(message.payload); if (responseHandler != null) - responseHandler.response(null); + responseHandler.onResponse(null); } catch (Exception ex) { if (!(ex instanceof WriteTimeoutException)) logger.error("Failed to apply paxos commit locally : ", ex); - responseHandler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.UNKNOWN); + responseHandler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.forException(ex)); } } @Override protected Verb verb() { - return MessagingService.Verb.PAXOS_COMMIT; + return PAXOS_COMMIT_REQ; } }); } @@ -999,22 +988,22 @@ private static void syncWriteToBatchlog(Collection mutations, ReplicaP queryStartNanoTime); Batch batch = Batch.createLocal(uuid, FBUtilities.timestampMicros(), mutations); - MessageOut message = new MessageOut<>(MessagingService.Verb.BATCH_STORE, batch, Batch.serializer); + Message message = Message.out(BATCH_STORE_REQ, batch); for (Replica replica : replicaPlan.liveAndDown()) { logger.trace("Sending batchlog store request {} to {} for {} mutations", batch.id, replica, batch.size()); if (replica.isSelf()) - performLocally(Stage.MUTATION, replica, Optional.empty(), () -> BatchlogManager.store(batch), handler); + performLocally(Stage.MUTATION, replica, () -> BatchlogManager.store(batch), handler); else - MessagingService.instance().sendRR(message, replica.endpoint(), handler); + MessagingService.instance().sendWithCallback(message, replica.endpoint(), handler); } handler.get(); } private static void 
asyncRemoveFromBatchlog(ReplicaPlan.ForTokenWrite replicaPlan, UUID uuid) { - MessageOut message = new MessageOut<>(MessagingService.Verb.BATCH_REMOVE, uuid, UUIDSerializer.serializer); + Message message = Message.out(Verb.BATCH_REMOVE_REQ, uuid); for (Replica target : replicaPlan.contacts()) { if (logger.isTraceEnabled()) @@ -1023,7 +1012,7 @@ private static void asyncRemoveFromBatchlog(ReplicaPlan.ForTokenWrite replicaPla if (target.isSelf()) performLocally(Stage.MUTATION, target, () -> BatchlogManager.remove(uuid)); else - MessagingService.instance().sendOneWay(message, target.endpoint()); + MessagingService.instance().send(message, target.endpoint()); } } @@ -1040,7 +1029,7 @@ private static void asyncWriteBatchedMutations(List } catch (OverloadedException | WriteTimeoutException e) { - wrapper.handler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.UNKNOWN); + wrapper.handler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.forException(e)); } } } @@ -1136,7 +1125,7 @@ private static WriteResponseHandlerWrapper wrapViewBatchResponseHandler(Mutation AbstractWriteResponseHandler writeHandler = rs.getWriteResponseHandler(replicaPlan, () -> { long delay = Math.max(0, System.currentTimeMillis() - baseComplete.get()); - viewWriteMetrics.viewWriteLatency.update(delay, TimeUnit.MILLISECONDS); + viewWriteMetrics.viewWriteLatency.update(delay, MILLISECONDS); }, writeType, queryStartNanoTime); BatchlogResponseHandler batchHandler = new ViewWriteMetricsWrapped(writeHandler, batchConsistencyLevel.blockFor(keyspace), cleanup, queryStartNanoTime); return new WriteResponseHandlerWrapper(batchHandler, mutation); @@ -1184,7 +1173,7 @@ public static void sendToHintedReplicas(final Mutation mutation, // extra-datacenter replicas, grouped by dc Map> dcGroups = null; // only need to create a Message for non-local writes - MessageOut message = null; + Message message = null; boolean insertLocal = false; Replica localReplica = null; @@ 
-1207,7 +1196,7 @@ public static void sendToHintedReplicas(final Mutation mutation, { // belongs on a different server if (message == null) - message = mutation.createMessage(); + message = Message.outWithFlag(MUTATION_REQ, mutation, MessageFlag.CALL_BACK_ON_FAILURE); String dc = DatabaseDescriptor.getEndpointSnitch().getDatacenter(destination); @@ -1253,7 +1242,7 @@ public static void sendToHintedReplicas(final Mutation mutation, } if (backPressureHosts != null) - MessagingService.instance().applyBackPressure(backPressureHosts, responseHandler.currentTimeout()); + MessagingService.instance().applyBackPressure(backPressureHosts, responseHandler.currentTimeoutNanos()); if (endpointsToHint != null) submitHint(mutation, EndpointsForToken.copyOf(mutation.key().getToken(), endpointsToHint), responseHandler); @@ -1261,13 +1250,13 @@ public static void sendToHintedReplicas(final Mutation mutation, if (insertLocal) { Preconditions.checkNotNull(localReplica); - performLocally(stage, localReplica, Optional.of(mutation), mutation::apply, responseHandler); + performLocally(stage, localReplica, mutation::apply, responseHandler); } if (localDc != null) { for (Replica destination : localDc) - MessagingService.instance().sendWriteRR(message, destination, responseHandler, true); + MessagingService.instance().sendWriteWithCallback(message, destination, responseHandler, true); } if (dcGroups != null) { @@ -1293,33 +1282,34 @@ private static void checkHintOverload(Replica destination) } } - private static void sendMessagesToNonlocalDC(MessageOut message, + /* + * Send the message to the first replica of targets, and have it forward the message to others in its DC + * + * TODO: are targets shuffled? do we want them to be to spread out forwarding burden? 
+ */ + private static void sendMessagesToNonlocalDC(Message message, EndpointsForToken targets, AbstractWriteResponseHandler handler) { - Iterator iter = targets.iterator(); - int[] messageIds = new int[targets.size()]; - Replica target = iter.next(); - - int idIdx = 0; - // Add the other destinations of the same message as a FORWARD_HEADER entry - while (iter.hasNext()) + if (targets.size() > 1) { - Replica destination = iter.next(); - int id = MessagingService.instance().addWriteCallback(handler, - message, - destination, - message.getTimeout(), - handler.replicaPlan.consistencyLevel(), - true); - messageIds[idIdx++] = id; - logger.trace("Adding FWD message to {}@{}", id, destination); + EndpointsForToken forwardToReplicas = targets.subList(1, targets.size()); + + for (Replica replica : forwardToReplicas) + { + MessagingService.instance().callbacks.addWithExpiration(handler, message, replica, handler.replicaPlan.consistencyLevel(), true); + logger.trace("Adding FWD message to {}@{}", message.id(), replica); + } + + // starting with 4.0, use the same message id for all replicas + long[] messageIds = new long[forwardToReplicas.size()]; + Arrays.fill(messageIds, message.id()); + + message = message.withForwardTo(new ForwardingInfo(forwardToReplicas.endpointList(), messageIds)); } - message = message.withParameter(ParameterType.FORWARD_TO, new ForwardToContainer(targets.endpoints(), messageIds)); - // send the combined message + forward headers - int id = MessagingService.instance().sendWriteRR(message, target, handler, true); - logger.trace("Sending message to {}@{}", id, target); + MessagingService.instance().sendWriteWithCallback(message, targets.get(0), handler, true); + logger.trace("Sending message to {}@{}", message.id(), targets.get(0)); } private static void performLocally(Stage stage, Replica localReplica, final Runnable runnable) @@ -1341,34 +1331,34 @@ public void runMayThrow() @Override protected Verb verb() { - return MessagingService.Verb.MUTATION; + 
return Verb.MUTATION_REQ; } }); } - private static void performLocally(Stage stage, Replica localReplica, Optional mutation, final Runnable runnable, final IAsyncCallbackWithFailure handler) + private static void performLocally(Stage stage, Replica localReplica, final Runnable runnable, final RequestCallback handler) { - StageManager.getStage(stage).maybeExecuteImmediately(new LocalMutationRunnable(localReplica, mutation) + StageManager.getStage(stage).maybeExecuteImmediately(new LocalMutationRunnable(localReplica) { public void runMayThrow() { try { runnable.run(); - handler.response(null); + handler.onResponse(null); } catch (Exception ex) { if (!(ex instanceof WriteTimeoutException)) logger.error("Failed to apply mutation locally : ", ex); - handler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.UNKNOWN); + handler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.forException(ex)); } } @Override protected Verb verb() { - return MessagingService.Verb.MUTATION; + return Verb.MUTATION_REQ; } }); } @@ -1410,7 +1400,8 @@ public static AbstractWriteResponseHandler mutateCounter(CounterMutat WriteType.COUNTER, queryStartNanoTime); Tracing.trace("Enqueuing counter update to {}", replica); - MessagingService.instance().sendWriteRR(cm.makeMutationMessage(), replica, responseHandler, false); + Message message = Message.outWithFlag(Verb.COUNTER_MUTATION_REQ, cm, MessageFlag.CALL_BACK_ON_FAILURE); + MessagingService.instance().sendWriteWithCallback(message, replica, responseHandler, false); return responseHandler; } } @@ -1479,7 +1470,7 @@ private static Runnable counterWriteTask(final IMutation mutation, final AbstractWriteResponseHandler responseHandler, final String localDataCenter) { - return new DroppableRunnable(MessagingService.Verb.COUNTER_MUTATION) + return new DroppableRunnable(Verb.COUNTER_MUTATION_REQ) { @Override public void runMayThrow() throws OverloadedException, WriteTimeoutException @@ -1487,7 +1478,7 @@ 
public void runMayThrow() throws OverloadedException, WriteTimeoutException assert mutation instanceof CounterMutation; Mutation result = ((CounterMutation) mutation).applyCounterMutation(); - responseHandler.response(null); + responseHandler.onResponse(null); sendToHintedReplicas(result, replicaPlan, responseHandler, localDataCenter, Stage.COUNTER_MUTATION); } }; @@ -1764,11 +1755,10 @@ public static class LocalReadRunnable extends DroppableRunnable { private final ReadCommand command; private final ReadCallback handler; - private final long start = System.nanoTime(); public LocalReadRunnable(ReadCommand command, ReadCallback handler) { - super(MessagingService.Verb.READ); + super(Verb.READ_REQ); this.command = command; this.handler = handler; } @@ -1777,7 +1767,7 @@ protected void runMayThrow() { try { - command.setMonitoringTime(constructionTime, false, verb.getTimeout(), DatabaseDescriptor.getSlowQueryTimeout()); + command.setMonitoringTime(approxCreationTimeNanos, false, verb.expiresAfterNanos(), DatabaseDescriptor.getSlowQueryTimeout(NANOSECONDS)); ReadResponse response; try (ReadExecutionController executionController = command.executionController(); @@ -1792,11 +1782,11 @@ protected void runMayThrow() } else { - MessagingService.instance().incrementDroppedMessages(verb, System.currentTimeMillis() - constructionTime); + MessagingService.instance().metrics.recordSelfDroppedMessage(verb, MonotonicClock.approxTime.now() - approxCreationTimeNanos, NANOSECONDS); handler.onFailure(FBUtilities.getBroadcastAddressAndPort(), RequestFailureReason.UNKNOWN); } - MessagingService.instance().addLatency(FBUtilities.getBroadcastAddressAndPort(), TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + MessagingService.instance().latencySubscribers.add(FBUtilities.getBroadcastAddressAndPort(), MonotonicClock.approxTime.now() - approxCreationTimeNanos, NANOSECONDS); } catch (Throwable t) { @@ -2075,10 +2065,8 @@ private SingleRangeResponse query(ReplicaPlan.ForRangeRead 
replicaPlan, boolean { Tracing.trace("Enqueuing request to {}", replica); ReadCommand command = replica.isFull() ? rangeCommand : rangeCommand.copyAsTransientQuery(replica); - MessageOut message = command.createMessage(); - if (command.isTrackingRepairedStatus() && replica.isFull()) - message = message.withParameter(ParameterType.TRACK_REPAIRED_DATA, MessagingService.ONE_BYTE); - MessagingService.instance().sendRRWithFailure(message, replica.endpoint(), handler); + Message message = command.createMessage(command.isTrackingRepairedStatus() && replica.isFull()); + MessagingService.instance().sendWithCallback(message, replica.endpoint(), handler); } } @@ -2181,29 +2169,21 @@ public static Map> describeSchemaVersions(boolean withPort) final Set liveHosts = Gossiper.instance.getLiveMembers(); final CountDownLatch latch = new CountDownLatch(liveHosts.size()); - IAsyncCallback cb = new IAsyncCallback() + RequestCallback cb = message -> { - public void response(MessageIn message) - { - // record the response from the remote node. - versions.put(message.from, message.payload); - latch.countDown(); - } - - public boolean isLatencyForSnitch() - { - return false; - } + // record the response from the remote node. + versions.put(message.from(), message.payload); + latch.countDown(); }; // an empty message acts as a request to the SchemaVersionVerbHandler. - MessageOut message = new MessageOut(MessagingService.Verb.SCHEMA_CHECK); + Message message = Message.out(Verb.SCHEMA_VERSION_REQ, noPayload); for (InetAddressAndPort endpoint : liveHosts) - MessagingService.instance().sendRR(message, endpoint, cb); + MessagingService.instance().sendWithCallback(message, endpoint, cb); try { // wait for as long as possible. timeout-1s if possible. 
- latch.await(DatabaseDescriptor.getRpcTimeout(), TimeUnit.MILLISECONDS); + latch.await(DatabaseDescriptor.getRpcTimeout(NANOSECONDS), NANOSECONDS); } catch (InterruptedException ex) { @@ -2385,10 +2365,9 @@ public static void truncateBlocking(String keyspace, String cfname) throws Unava // Send out the truncate calls and track the responses with the callbacks. Tracing.trace("Enqueuing truncate messages to hosts {}", allEndpoints); - final Truncation truncation = new Truncation(keyspace, cfname); - MessageOut message = truncation.createMessage(); + Message message = Message.out(TRUNCATE_REQ, new TruncateRequest(keyspace, cfname)); for (InetAddressAndPort endpoint : allEndpoints) - MessagingService.instance().sendRR(message, endpoint, responseHandler); + MessagingService.instance().sendWithCallback(message, endpoint, responseHandler); // Wait for all try @@ -2430,9 +2409,9 @@ public ViewWriteMetricsWrapped(AbstractWriteResponseHandler writeHand viewWriteMetrics.viewReplicasAttempted.inc(candidateReplicaCount()); } - public void response(MessageIn msg) + public void onResponse(Message msg) { - super.response(msg); + super.onResponse(msg); viewWriteMetrics.viewReplicasSuccess.inc(); } } @@ -2442,21 +2421,23 @@ public void response(MessageIn msg) */ private static abstract class DroppableRunnable implements Runnable { - final long constructionTime; - final MessagingService.Verb verb; + final long approxCreationTimeNanos; + final Verb verb; - public DroppableRunnable(MessagingService.Verb verb) + public DroppableRunnable(Verb verb) { - this.constructionTime = System.currentTimeMillis(); + this.approxCreationTimeNanos = MonotonicClock.approxTime.now(); this.verb = verb; } public final void run() { - long timeTaken = System.currentTimeMillis() - constructionTime; - if (timeTaken > verb.getTimeout()) + long approxCurrentTimeNanos = MonotonicClock.approxTime.now(); + long expirationTimeNanos = verb.expiresAtNanos(approxCreationTimeNanos); + if (approxCurrentTimeNanos > 
expirationTimeNanos) { - MessagingService.instance().incrementDroppedMessages(verb, timeTaken); + long timeTakenNanos = approxCurrentTimeNanos - approxCreationTimeNanos; + MessagingService.instance().metrics.recordSelfDroppedMessage(verb, timeTakenNanos, NANOSECONDS); return; } try @@ -2478,32 +2459,24 @@ public final void run() */ private static abstract class LocalMutationRunnable implements Runnable { - private final long constructionTime = System.currentTimeMillis(); + private final long approxCreationTimeNanos = MonotonicClock.approxTime.now(); private final Replica localReplica; - private final Optional mutationOpt; - public LocalMutationRunnable(Replica localReplica, Optional mutationOpt) + LocalMutationRunnable(Replica localReplica) { this.localReplica = localReplica; - this.mutationOpt = mutationOpt; - } - - public LocalMutationRunnable(Replica localReplica) - { - this.localReplica = localReplica; - this.mutationOpt = Optional.empty(); } public final void run() { - final MessagingService.Verb verb = verb(); - long mutationTimeout = verb.getTimeout(); - long timeTaken = System.currentTimeMillis() - constructionTime; - if (timeTaken > mutationTimeout) + final Verb verb = verb(); + long nowNanos = MonotonicClock.approxTime.now(); + long expirationTimeNanos = verb.expiresAtNanos(approxCreationTimeNanos); + if (nowNanos > expirationTimeNanos) { - if (MessagingService.DROPPABLE_VERBS.contains(verb)) - MessagingService.instance().incrementDroppedMutations(mutationOpt, timeTaken); + long timeTakenNanos = nowNanos - approxCreationTimeNanos; + MessagingService.instance().metrics.recordSelfDroppedMessage(Verb.MUTATION_REQ, timeTakenNanos, NANOSECONDS); HintRunnable runnable = new HintRunnable(EndpointsForToken.of(localReplica.range().right, localReplica)) { @@ -2526,7 +2499,7 @@ protected void runMayThrow() throws Exception } } - abstract protected MessagingService.Verb verb(); + abstract protected Verb verb(); abstract protected void runMayThrow() throws Exception; 
} @@ -2634,7 +2607,7 @@ public void runMayThrow() validTargets.forEach(HintsService.instance.metrics::incrCreatedHints); // Notify the handler only for CL == ANY if (responseHandler != null && responseHandler.replicaPlan.consistencyLevel() == ConsistencyLevel.ANY) - responseHandler.response(null); + responseHandler.onResponse(null); } }; @@ -2649,25 +2622,25 @@ private static Future submitHint(HintRunnable runnable) return (Future) StageManager.getStage(Stage.MUTATION).submit(runnable); } - public Long getRpcTimeout() { return DatabaseDescriptor.getRpcTimeout(); } + public Long getRpcTimeout() { return DatabaseDescriptor.getRpcTimeout(MILLISECONDS); } public void setRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setRpcTimeout(timeoutInMillis); } - public Long getReadRpcTimeout() { return DatabaseDescriptor.getReadRpcTimeout(); } + public Long getReadRpcTimeout() { return DatabaseDescriptor.getReadRpcTimeout(MILLISECONDS); } public void setReadRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setReadRpcTimeout(timeoutInMillis); } - public Long getWriteRpcTimeout() { return DatabaseDescriptor.getWriteRpcTimeout(); } + public Long getWriteRpcTimeout() { return DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS); } public void setWriteRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setWriteRpcTimeout(timeoutInMillis); } - public Long getCounterWriteRpcTimeout() { return DatabaseDescriptor.getCounterWriteRpcTimeout(); } + public Long getCounterWriteRpcTimeout() { return DatabaseDescriptor.getCounterWriteRpcTimeout(MILLISECONDS); } public void setCounterWriteRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setCounterWriteRpcTimeout(timeoutInMillis); } - public Long getCasContentionTimeout() { return DatabaseDescriptor.getCasContentionTimeout(); } + public Long getCasContentionTimeout() { return DatabaseDescriptor.getCasContentionTimeout(MILLISECONDS); } public void setCasContentionTimeout(Long timeoutInMillis) { 
DatabaseDescriptor.setCasContentionTimeout(timeoutInMillis); } - public Long getRangeRpcTimeout() { return DatabaseDescriptor.getRangeRpcTimeout(); } + public Long getRangeRpcTimeout() { return DatabaseDescriptor.getRangeRpcTimeout(MILLISECONDS); } public void setRangeRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setRangeRpcTimeout(timeoutInMillis); } - public Long getTruncateRpcTimeout() { return DatabaseDescriptor.getTruncateRpcTimeout(); } + public Long getTruncateRpcTimeout() { return DatabaseDescriptor.getTruncateRpcTimeout(MILLISECONDS); } public void setTruncateRpcTimeout(Long timeoutInMillis) { DatabaseDescriptor.setTruncateRpcTimeout(timeoutInMillis); } public Long getNativeTransportMaxConcurrentConnections() { return DatabaseDescriptor.getNativeTransportMaxConcurrentConnections(); } @@ -2739,13 +2712,13 @@ public void stopFullQueryLogger() AuditLogManager.getInstance().disableFQL(); } + @Deprecated public int getOtcBacklogExpirationInterval() { - return DatabaseDescriptor.getOtcBacklogExpirationInterval(); + return 0; } - public void setOtcBacklogExpirationInterval(int intervalInMillis) { - DatabaseDescriptor.setOtcBacklogExpirationInterval(intervalInMillis); - } + @Deprecated + public void setOtcBacklogExpirationInterval(int intervalInMillis) { } @Override public void enableRepairedDataTrackingForRangeReads() diff --git a/src/java/org/apache/cassandra/service/StorageProxyMBean.java b/src/java/org/apache/cassandra/service/StorageProxyMBean.java index 95f5f26afff8..08b5cbd3b479 100644 --- a/src/java/org/apache/cassandra/service/StorageProxyMBean.java +++ b/src/java/org/apache/cassandra/service/StorageProxyMBean.java @@ -63,7 +63,9 @@ public interface StorageProxyMBean public long getReadRepairRepairedBlocking(); public long getReadRepairRepairedBackground(); + @Deprecated public int getOtcBacklogExpirationInterval(); + @Deprecated public void setOtcBacklogExpirationInterval(int intervalInMillis); /** Returns each live node's schema version */ diff 
--git a/src/java/org/apache/cassandra/service/StorageService.java b/src/java/org/apache/cassandra/service/StorageService.java index eade7ddd4876..a8ae42c2f0c7 100644 --- a/src/java/org/apache/cassandra/service/StorageService.java +++ b/src/java/org/apache/cassandra/service/StorageService.java @@ -56,8 +56,6 @@ import org.apache.cassandra.audit.AuditLogOptions; import org.apache.cassandra.auth.AuthKeyspace; import org.apache.cassandra.auth.AuthSchemaChangeListener; -import org.apache.cassandra.batchlog.BatchRemoveVerbHandler; -import org.apache.cassandra.batchlog.BatchStoreVerbHandler; import org.apache.cassandra.batchlog.BatchlogManager; import org.apache.cassandra.concurrent.ExecutorLocals; import org.apache.cassandra.concurrent.NamedThreadFactory; @@ -78,7 +76,6 @@ import org.apache.cassandra.dht.Token.TokenFactory; import org.apache.cassandra.exceptions.*; import org.apache.cassandra.gms.*; -import org.apache.cassandra.hints.HintVerbHandler; import org.apache.cassandra.hints.HintsService; import org.apache.cassandra.io.sstable.SSTableLoader; import org.apache.cassandra.io.util.FileUtils; @@ -93,16 +90,9 @@ import org.apache.cassandra.schema.ReplicationParams; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; -import org.apache.cassandra.schema.SchemaPullVerbHandler; -import org.apache.cassandra.schema.SchemaPushVerbHandler; -import org.apache.cassandra.schema.SchemaVersionVerbHandler; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.schema.TableMetadataRef; import org.apache.cassandra.schema.ViewMetadata; -import org.apache.cassandra.repair.RepairMessageVerbHandler; -import org.apache.cassandra.service.paxos.CommitVerbHandler; -import org.apache.cassandra.service.paxos.PrepareVerbHandler; -import org.apache.cassandra.service.paxos.ProposeVerbHandler; import org.apache.cassandra.streaming.*; import org.apache.cassandra.tracing.TraceKeyspace; import 
org.apache.cassandra.transport.ProtocolVersion; @@ -110,15 +100,21 @@ import org.apache.cassandra.utils.logging.LoggingSupportFactory; import org.apache.cassandra.utils.progress.ProgressEvent; import org.apache.cassandra.utils.progress.ProgressEventType; +import org.apache.cassandra.utils.progress.ProgressListener; import org.apache.cassandra.utils.progress.jmx.JMXBroadcastExecutor; import org.apache.cassandra.utils.progress.jmx.JMXProgressSupport; import static com.google.common.collect.Iterables.transform; import static com.google.common.collect.Iterables.tryFind; import static java.util.Arrays.asList; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.stream.Collectors.toList; import static org.apache.cassandra.index.SecondaryIndexManager.getIndexName; import static org.apache.cassandra.index.SecondaryIndexManager.isIndexColumnFamily; +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.Verb.REPLICATION_DONE_REQ; /** * This abstraction contains the token/identifier of this node @@ -283,44 +279,6 @@ public StorageService() jmxObjectName = "org.apache.cassandra.db:type=StorageService"; MBeanWrapper.instance.registerMBean(this, jmxObjectName); MBeanWrapper.instance.registerMBean(StreamManager.instance, StreamManager.OBJECT_NAME); - - ReadCommandVerbHandler readHandler = new ReadCommandVerbHandler(); - - /* register the verb handlers */ - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.MUTATION, new MutationVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.READ_REPAIR, new ReadRepairVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.READ, readHandler); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.RANGE_SLICE, readHandler); - 
MessagingService.instance().registerVerbHandlers(MessagingService.Verb.PAGED_RANGE, readHandler); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.COUNTER_MUTATION, new CounterMutationVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.TRUNCATE, new TruncateVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.PAXOS_PREPARE, new PrepareVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.PAXOS_PROPOSE, new ProposeVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.PAXOS_COMMIT, new CommitVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.HINT, new HintVerbHandler()); - - // see BootStrapper for a summary of how the bootstrap verbs interact - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.REPLICATION_FINISHED, new ReplicationFinishedVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.REQUEST_RESPONSE, new ResponseVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.INTERNAL_RESPONSE, new ResponseVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.REPAIR_MESSAGE, new RepairMessageVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.GOSSIP_SHUTDOWN, new GossipShutdownVerbHandler()); - - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.GOSSIP_DIGEST_SYN, new GossipDigestSynVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.GOSSIP_DIGEST_ACK, new GossipDigestAckVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.GOSSIP_DIGEST_ACK2, new GossipDigestAck2VerbHandler()); - - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.DEFINITIONS_UPDATE, new SchemaPushVerbHandler()); - 
MessagingService.instance().registerVerbHandlers(MessagingService.Verb.SCHEMA_CHECK, new SchemaVersionVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.MIGRATION_REQUEST, new SchemaPullVerbHandler()); - - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.SNAPSHOT, new SnapshotVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.ECHO, new EchoVerbHandler()); - - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.BATCH_STORE, new BatchStoreVerbHandler()); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.BATCH_REMOVE, new BatchRemoveVerbHandler()); - - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.PING, new PingVerbHandler()); } public void registerDaemon(CassandraDaemon daemon) @@ -625,8 +583,7 @@ public void unsafeInitialize() throws ConfigurationException Gossiper.instance.register(this); Gossiper.instance.start((int) (System.currentTimeMillis() / 1000)); // needed for node-ring gathering. 
Gossiper.instance.addLocalApplicationState(ApplicationState.NET_VERSION, valueFactory.networkVersion()); - if (!MessagingService.instance().isListening()) - MessagingService.instance().listen(); + MessagingService.instance().listen(); } public void populateTokenMetadata() @@ -808,8 +765,7 @@ private void prepareToJoin() throws ConfigurationException if (DatabaseDescriptor.getReplaceTokens().size() > 0 || DatabaseDescriptor.getReplaceNode() != null) throw new RuntimeException("Replace method removed; use cassandra.replace_address instead"); - if (!MessagingService.instance().isListening()) - MessagingService.instance().listen(); + MessagingService.instance().listen(); UUID localHostId = SystemKeyspace.getLocalHostId(); @@ -1356,7 +1312,7 @@ public void setRpcTimeout(long value) public long getRpcTimeout() { - return DatabaseDescriptor.getRpcTimeout(); + return DatabaseDescriptor.getRpcTimeout(MILLISECONDS); } public void setReadRpcTimeout(long value) @@ -1367,7 +1323,7 @@ public void setReadRpcTimeout(long value) public long getReadRpcTimeout() { - return DatabaseDescriptor.getReadRpcTimeout(); + return DatabaseDescriptor.getReadRpcTimeout(MILLISECONDS); } public void setRangeRpcTimeout(long value) @@ -1378,7 +1334,7 @@ public void setRangeRpcTimeout(long value) public long getRangeRpcTimeout() { - return DatabaseDescriptor.getRangeRpcTimeout(); + return DatabaseDescriptor.getRangeRpcTimeout(MILLISECONDS); } public void setWriteRpcTimeout(long value) @@ -1389,7 +1345,7 @@ public void setWriteRpcTimeout(long value) public long getWriteRpcTimeout() { - return DatabaseDescriptor.getWriteRpcTimeout(); + return DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS); } public void setInternodeTcpConnectTimeoutInMS(int value) @@ -1422,7 +1378,7 @@ public void setCounterWriteRpcTimeout(long value) public long getCounterWriteRpcTimeout() { - return DatabaseDescriptor.getCounterWriteRpcTimeout(); + return DatabaseDescriptor.getCounterWriteRpcTimeout(MILLISECONDS); } public void 
setCasContentionTimeout(long value) @@ -1433,7 +1389,7 @@ public void setCasContentionTimeout(long value) public long getCasContentionTimeout() { - return DatabaseDescriptor.getCasContentionTimeout(); + return DatabaseDescriptor.getCasContentionTimeout(MILLISECONDS); } public void setTruncateRpcTimeout(long value) @@ -1444,7 +1400,7 @@ public void setTruncateRpcTimeout(long value) public long getTruncateRpcTimeout() { - return DatabaseDescriptor.getTruncateRpcTimeout(); + return DatabaseDescriptor.getTruncateRpcTimeout(MILLISECONDS); } public void setStreamThroughputMbPerSec(int value) @@ -1581,7 +1537,7 @@ private boolean bootstrap(final Collection tokens) valueFactory.bootstrapping(tokens))); Gossiper.instance.addLocalApplicationStates(states); setMode(Mode.JOINING, "sleeping " + RING_DELAY + " ms for pending range setup", true); - Uninterruptibles.sleepUninterruptibly(RING_DELAY, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(RING_DELAY, MILLISECONDS); } else { @@ -2212,7 +2168,7 @@ private void updateNetVersion(InetAddressAndPort endpoint, VersionedValue value) { try { - MessagingService.instance().setVersion(endpoint, Integer.parseInt(value.value)); + MessagingService.instance().versions.set(endpoint, Integer.parseInt(value.value)); } catch (NumberFormatException e) { @@ -2457,6 +2413,84 @@ private void handleStateBootreplacing(InetAddressAndPort newNode, String[] piece tokenMetadata.updateHostId(Gossiper.instance.getHostId(newNode), newNode); } + private void ensureUpToDateTokenMetadata(String status, InetAddressAndPort endpoint) + { + Set tokens = new TreeSet<>(getTokensFor(endpoint)); + + if (logger.isDebugEnabled()) + logger.debug("Node {} state {}, tokens {}", endpoint, status, tokens); + + // If the node is previously unknown or tokens do not match, update tokenmetadata to + // have this node as 'normal' (it must have been using this token before the + // leave). This way we'll get pending ranges right. 
+ if (!tokenMetadata.isMember(endpoint)) + { + logger.info("Node {} state jump to {}", endpoint, status); + updateTokenMetadata(endpoint, tokens); + } + else if (!tokens.equals(new TreeSet<>(tokenMetadata.getTokens(endpoint)))) + { + logger.warn("Node {} '{}' token mismatch. Long network partition?", endpoint, status); + updateTokenMetadata(endpoint, tokens); + } + } + + private void updateTokenMetadata(InetAddressAndPort endpoint, Iterable tokens) + { + updateTokenMetadata(endpoint, tokens, new HashSet<>()); + } + + private void updateTokenMetadata(InetAddressAndPort endpoint, Iterable tokens, Set endpointsToRemove) + { + Set tokensToUpdateInMetadata = new HashSet<>(); + Set tokensToUpdateInSystemKeyspace = new HashSet<>(); + + for (final Token token : tokens) + { + // we don't want to update if this node is responsible for the token and it has a later startup time than endpoint. + InetAddressAndPort currentOwner = tokenMetadata.getEndpoint(token); + if (currentOwner == null) + { + logger.debug("New node {} at token {}", endpoint, token); + tokensToUpdateInMetadata.add(token); + tokensToUpdateInSystemKeyspace.add(token); + } + else if (endpoint.equals(currentOwner)) + { + // set state back to normal, since the node may have tried to leave, but failed and is now back up + tokensToUpdateInMetadata.add(token); + tokensToUpdateInSystemKeyspace.add(token); + } + else if (Gossiper.instance.compareEndpointStartup(endpoint, currentOwner) > 0) + { + tokensToUpdateInMetadata.add(token); + tokensToUpdateInSystemKeyspace.add(token); + + // currentOwner is no longer current, endpoint is. Keep track of these moves, because when + // a host no longer has any tokens, we'll want to remove it. + Multimap epToTokenCopy = getTokenMetadata().getEndpointToTokenMapForReading(); + epToTokenCopy.get(currentOwner).remove(token); + if (epToTokenCopy.get(currentOwner).isEmpty()) + endpointsToRemove.add(currentOwner); + + logger.info("Nodes {} and {} have the same token {}. 
{} is the new owner", endpoint, currentOwner, token, endpoint); + } + else + { + logger.info("Nodes {} and {} have the same token {}. Ignoring {}", endpoint, currentOwner, token, endpoint); + } + } + + tokenMetadata.updateNormalTokens(tokensToUpdateInMetadata, endpoint); + for (InetAddressAndPort ep : endpointsToRemove) + { + removeEndpoint(ep); + if (replacing && ep.equals(DatabaseDescriptor.getReplaceAddress())) + Gossiper.instance.replacementQuarantine(ep); // quarantine locally longer than normally; see CASSANDRA-8260 + } + if (!tokensToUpdateInSystemKeyspace.isEmpty()) + SystemKeyspace.updateTokens(endpoint, tokensToUpdateInSystemKeyspace); + } /** * Handle node move to normal state. That is, node is entering token ring and participating * in reads. @@ -2466,8 +2500,6 @@ private void handleStateBootreplacing(InetAddressAndPort newNode, String[] piece private void handleStateNormal(final InetAddressAndPort endpoint, final String status) { Collection tokens = getTokensFor(endpoint); - Set tokensToUpdateInMetadata = new HashSet<>(); - Set tokensToUpdateInSystemKeyspace = new HashSet<>(); Set endpointsToRemove = new HashSet<>(); if (logger.isDebugEnabled()) @@ -2535,62 +2567,11 @@ else if (Gossiper.instance.compareEndpointStartup(endpoint, existing) > 0) tokenMetadata.updateHostId(hostId, endpoint); } - for (final Token token : tokens) - { - // we don't want to update if this node is responsible for the token and it has a later startup time than endpoint.
- InetAddressAndPort currentOwner = tokenMetadata.getEndpoint(token); - if (currentOwner == null) - { - logger.debug("New node {} at token {}", endpoint, token); - tokensToUpdateInMetadata.add(token); - tokensToUpdateInSystemKeyspace.add(token); - } - else if (endpoint.equals(currentOwner)) - { - // set state back to normal, since the node may have tried to leave, but failed and is now back up - tokensToUpdateInMetadata.add(token); - tokensToUpdateInSystemKeyspace.add(token); - } - else if (Gossiper.instance.compareEndpointStartup(endpoint, currentOwner) > 0) - { - tokensToUpdateInMetadata.add(token); - tokensToUpdateInSystemKeyspace.add(token); - - // currentOwner is no longer current, endpoint is. Keep track of these moves, because when - // a host no longer has any tokens, we'll want to remove it. - Multimap epToTokenCopy = getTokenMetadata().getEndpointToTokenMapForReading(); - epToTokenCopy.get(currentOwner).remove(token); - if (epToTokenCopy.get(currentOwner).size() < 1) - endpointsToRemove.add(currentOwner); - - logger.info("Nodes {} and {} have the same token {}. {} is the new owner", - endpoint, - currentOwner, - token, - endpoint); - } - else - { - logger.info("Nodes {} and {} have the same token {}. 
Ignoring {}", - endpoint, - currentOwner, - token, - endpoint); - } - } - // capture because updateNormalTokens clears moving and member status boolean isMember = tokenMetadata.isMember(endpoint); boolean isMoving = tokenMetadata.isMoving(endpoint); - tokenMetadata.updateNormalTokens(tokensToUpdateInMetadata, endpoint); - for (InetAddressAndPort ep : endpointsToRemove) - { - removeEndpoint(ep); - if (replacing && DatabaseDescriptor.getReplaceAddress().equals(ep)) - Gossiper.instance.replacementQuarantine(ep); // quarantine locally longer than normally; see CASSANDRA-8260 - } - if (!tokensToUpdateInSystemKeyspace.isEmpty()) - SystemKeyspace.updateTokens(endpoint, tokensToUpdateInSystemKeyspace); + + updateTokenMetadata(endpoint, tokens, endpointsToRemove); if (isMoving || operationMode == Mode.MOVING) { @@ -2612,24 +2593,11 @@ else if (!isMember) // prior to this, the node was not a member */ private void handleStateLeaving(InetAddressAndPort endpoint) { - Collection tokens = getTokensFor(endpoint); - - if (logger.isDebugEnabled()) - logger.debug("Node {} state leaving, tokens {}", endpoint, tokens); - // If the node is previously unknown or tokens do not match, update tokenmetadata to // have this node as 'normal' (it must have been using this token before the // leave). This way we'll get pending ranges right. - if (!tokenMetadata.isMember(endpoint)) - { - logger.info("Node {} state jump to leaving", endpoint); - tokenMetadata.updateNormalTokens(tokens, endpoint); - } - else if (!tokenMetadata.getTokens(endpoint).containsAll(tokens)) - { - logger.warn("Node {} 'leaving' token mismatch. 
Long network partition?", endpoint); - tokenMetadata.updateNormalTokens(tokens, endpoint); - } + + ensureUpToDateTokenMetadata(VersionedValue.STATUS_LEAVING, endpoint); // at this point the endpoint is certainly a member with this token, so let's proceed // normally @@ -2662,6 +2630,8 @@ private void handleStateLeft(InetAddressAndPort endpoint, String[] pieces) */ private void handleStateMoving(InetAddressAndPort endpoint, String[] pieces) { + ensureUpToDateTokenMetadata(VersionedValue.STATUS_MOVING, endpoint); + assert pieces.length >= 2; Token token = getTokenFactory().fromString(pieces[1]); @@ -2707,6 +2677,8 @@ private void handleStateRemoving(InetAddressAndPort endpoint, String[] pieces) } else if (VersionedValue.REMOVING_TOKEN.equals(state)) { + ensureUpToDateTokenMetadata(state, endpoint); + if (logger.isDebugEnabled()) logger.debug("Tokens {} removed manually (endpoint was {})", removeTokens, endpoint); @@ -2738,8 +2710,8 @@ private void excise(Collection tokens, InetAddressAndPort endpoint) { // enough time for writes to expire and MessagingService timeout reporter callback to fire, which is where // hints are mostly written from - using getMinRpcTimeout() / 2 for the interval. 
- long delay = DatabaseDescriptor.getMinRpcTimeout() + DatabaseDescriptor.getWriteRpcTimeout(); - ScheduledExecutors.optionalTasks.schedule(() -> HintsService.instance.excise(hostId), delay, TimeUnit.MILLISECONDS); + long delay = DatabaseDescriptor.getMinRpcTimeout(MILLISECONDS) + DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS); + ScheduledExecutors.optionalTasks.schedule(() -> HintsService.instance.excise(hostId), delay, MILLISECONDS); } removeEndpoint(endpoint); @@ -2843,22 +2815,22 @@ private Multimap getNewSourceReplicas(String k private void sendReplicationNotification(InetAddressAndPort remote) { // notify the remote token - MessageOut msg = new MessageOut(MessagingService.Verb.REPLICATION_FINISHED); + Message msg = Message.out(REPLICATION_DONE_REQ, noPayload); IFailureDetector failureDetector = FailureDetector.instance; if (logger.isDebugEnabled()) logger.debug("Notifying {} of replication completion\n", remote); while (failureDetector.isAlive(remote)) { - AsyncOneResponse iar = MessagingService.instance().sendRR(msg, remote); - try - { - iar.get(DatabaseDescriptor.getRpcTimeout(), TimeUnit.MILLISECONDS); - return; // done - } - catch(TimeoutException e) - { - // try again - } + AsyncOneResponse ior = new AsyncOneResponse(); + MessagingService.instance().sendWithCallback(msg, remote, ior); + + if (!ior.awaitUninterruptibly(DatabaseDescriptor.getRpcTimeout(NANOSECONDS), NANOSECONDS)) + continue; // try again if we timeout + + if (!ior.isSuccess()) + throw new AssertionError(ior.cause()); + + return; } } @@ -3082,7 +3054,9 @@ public void onRemove(InetAddressAndPort endpoint) public void onDead(InetAddressAndPort endpoint, EndpointState state) { - MessagingService.instance().convict(endpoint); + // interrupt any outbound connection; if the node is failing and we cannot reconnect, + // this will rapidly lower the number of bytes we are willing to queue to the node + MessagingService.instance().interruptOutbound(endpoint); notifyDown(endpoint); } @@ -3760,11 
+3734,16 @@ public void forceKeyspaceFlush(String keyspaceName, String... tableNames) throws for (ColumnFamilyStore cfStore : getValidColumnFamilies(true, false, keyspaceName, tableNames)) { logger.debug("Forcing flush on keyspace {}, CF {}", keyspaceName, cfStore.name); - cfStore.forceBlockingFlush(); + cfStore.forceBlockingFlushToSSTable(); } } public int repairAsync(String keyspace, Map repairSpec) + { + return repair(keyspace, repairSpec, Collections.emptyList()).left; + } + + public Pair> repair(String keyspace, Map repairSpec, List listeners) { RepairOption option = RepairOption.parse(repairSpec, tokenMetadata.partitioner); // if ranges are not specified @@ -3787,11 +3766,10 @@ else if (option.isInLocalDCOnly()) } } if (option.getRanges().isEmpty() || Keyspace.open(keyspace).getReplicationStrategy().getReplicationFactor().allReplicas < 2) - return 0; + return Pair.create(0, Futures.immediateFuture(null)); int cmd = nextRepairCommand.incrementAndGet(); - ActiveRepairService.repairCommandExecutor.execute(createRepairTask(cmd, keyspace, option)); - return cmd; + return Pair.create(cmd, ActiveRepairService.repairCommandExecutor.submit(createRepairTask(cmd, keyspace, option, listeners))); } /** @@ -3837,7 +3815,7 @@ public TokenFactory getTokenFactory() return tokenMetadata.partitioner.getTokenFactory(); } - private FutureTask createRepairTask(final int cmd, final String keyspace, final RepairOption options) + private FutureTask createRepairTask(final int cmd, final String keyspace, final RepairOption options, List listeners) { if (!options.getDataCenters().isEmpty() && !options.getDataCenters().contains(DatabaseDescriptor.getLocalDataCenter())) { @@ -3846,6 +3824,9 @@ private FutureTask createRepairTask(final int cmd, final String keyspace RepairRunnable task = new RepairRunnable(this, cmd, options, keyspace); task.addProgressListener(progressSupport); + for (ProgressListener listener : listeners) + task.addProgressListener(listener); + if (options.isTraced()) { 
Runnable r = () -> @@ -4233,7 +4214,7 @@ private void leaveRing() Gossiper.instance.addLocalApplicationState(ApplicationState.STATUS, valueFactory.left(getLocalTokens(),Gossiper.computeExpireTime())); int delay = Math.max(RING_DELAY, Gossiper.intervalInMillis * 2); logger.info("Announcing that I have left the ring for {}ms", delay); - Uninterruptibles.sleepUninterruptibly(delay, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(delay, MILLISECONDS); } private void unbootstrap(Runnable onFinish) throws ExecutionException, InterruptedException @@ -4362,7 +4343,7 @@ private void move(Token newToken) throws IOException setMode(Mode.MOVING, String.format("Moving %s from %s to %s.", localAddress, getLocalTokens().iterator().next(), newToken), true); setMode(Mode.MOVING, String.format("Sleeping %s ms before start streaming/fetching ranges", RING_DELAY), true); - Uninterruptibles.sleepUninterruptibly(RING_DELAY, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(RING_DELAY, MILLISECONDS); RangeRelocator relocator = new RangeRelocator(Collections.singleton(newToken), keyspacesToProcess, tokenMetadata); relocator.calculateToFromStreams(); @@ -4518,10 +4499,10 @@ public void removeNode(String hostIdString) // kick off streaming commands restoreReplicaCount(endpoint, myAddress); - // wait for ReplicationFinishedVerbHandler to signal we're done + // wait for ReplicationDoneVerbHandler to signal we're done while (!replicatingNodes.isEmpty()) { - Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(100, MILLISECONDS); } excise(tokens, endpoint); @@ -4617,7 +4598,16 @@ protected synchronized void drain(boolean isFinalShutdown) throws IOException, I { setMode(Mode.DRAINING, "starting drain process", !isFinalShutdown); - BatchlogManager.instance.shutdown(); + try + { + /* not clear this is reasonable time, but propagated from prior embedded behaviour */ + BatchlogManager.instance.shutdownAndWait(1L, 
MINUTES); + } + catch (TimeoutException t) + { + logger.error("Batchlog manager timed out shutting down", t); + } + HintsService.instance.pauseDispatch(); if (daemon != null) @@ -4661,7 +4651,7 @@ protected synchronized void drain(boolean isFinalShutdown) throws IOException, I for (Keyspace keyspace : Keyspace.nonSystem()) { for (ColumnFamilyStore cfs : keyspace.getColumnFamilyStores()) - flushes.add(cfs.forceFlush()); + flushes.add(cfs.forceFlushToSSTable()); } // wait for the flushes. // TODO this is a godawful way to track progress, since they flush in parallel. a long one could @@ -4693,7 +4683,7 @@ protected synchronized void drain(boolean isFinalShutdown) throws IOException, I for (Keyspace keyspace : Keyspace.system()) { for (ColumnFamilyStore cfs : keyspace.getColumnFamilyStores()) - flushes.add(cfs.forceFlush()); + flushes.add(cfs.forceFlushToSSTable()); } FBUtilities.waitOnFutures(flushes); @@ -4710,7 +4700,7 @@ protected synchronized void drain(boolean isFinalShutdown) throws IOException, I // wait for miscellaneous tasks like sstable and commitlog segment deletion ScheduledExecutors.nonPeriodicTasks.shutdown(); - if (!ScheduledExecutors.nonPeriodicTasks.awaitTermination(1, TimeUnit.MINUTES)) + if (!ScheduledExecutors.nonPeriodicTasks.awaitTermination(1, MINUTES)) logger.warn("Failed to wait for non periodic tasks to shutdown"); ColumnFamilyStore.shutdownPostFlushExecutor(); @@ -5250,7 +5240,7 @@ public Map> samplePartitions(int durationMillis, int table.beginLocalSampling(sampler, capacity, durationMillis); } } - Uninterruptibles.sleepUninterruptibly(durationMillis, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(durationMillis, MILLISECONDS); for (String sampler : samplers) { @@ -5469,4 +5459,13 @@ public void setCorruptedTombstoneStrategy(String strategy) DatabaseDescriptor.setCorruptedTombstoneStrategy(Config.CorruptedTombstoneStrategy.valueOf(strategy)); logger.info("Setting corrupted tombstone strategy to {}", strategy); } + + 
@VisibleForTesting + public void shutdownServer() + { + if (drainOnShutdown != null) + { + Runtime.getRuntime().removeShutdownHook(drainOnShutdown); + } + } } diff --git a/src/java/org/apache/cassandra/service/TruncateResponseHandler.java b/src/java/org/apache/cassandra/service/TruncateResponseHandler.java index cce8ecc111eb..bcd7426fd04e 100644 --- a/src/java/org/apache/cassandra/service/TruncateResponseHandler.java +++ b/src/java/org/apache/cassandra/service/TruncateResponseHandler.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.service; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; @@ -25,11 +24,13 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.net.IAsyncCallback; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.utils.concurrent.SimpleCondition; -public class TruncateResponseHandler implements IAsyncCallback +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +public class TruncateResponseHandler implements RequestCallback { protected static final Logger logger = LoggerFactory.getLogger(TruncateResponseHandler.class); protected final SimpleCondition condition = new SimpleCondition(); @@ -49,11 +50,11 @@ public TruncateResponseHandler(int responseCount) public void get() throws TimeoutException { - long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getTruncateRpcTimeout()) - (System.nanoTime() - start); + long timeoutNanos = DatabaseDescriptor.getTruncateRpcTimeout(NANOSECONDS) - (System.nanoTime() - start); boolean success; try { - success = condition.await(timeout, TimeUnit.NANOSECONDS); // TODO truncate needs a much longer timeout + success = condition.await(timeoutNanos, NANOSECONDS); // TODO truncate needs a much longer timeout } catch (InterruptedException 
ex) { @@ -66,15 +67,10 @@ public void get() throws TimeoutException } } - public void response(MessageIn message) + public void onResponse(Message message) { responses.incrementAndGet(); if (responses.get() >= responseCount) condition.signalAll(); } - - public boolean isLatencyForSnitch() - { - return false; - } } diff --git a/src/java/org/apache/cassandra/service/WriteResponseHandler.java b/src/java/org/apache/cassandra/service/WriteResponseHandler.java index f9bfedf007fc..94f5a80729a7 100644 --- a/src/java/org/apache/cassandra/service/WriteResponseHandler.java +++ b/src/java/org/apache/cassandra/service/WriteResponseHandler.java @@ -23,7 +23,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.db.WriteType; /** @@ -51,7 +51,7 @@ public WriteResponseHandler(ReplicaPlan.ForTokenWrite replicaPlan, WriteType wri this(replicaPlan, null, writeType, queryStartNanoTime); } - public void response(MessageIn m) + public void onResponse(Message m) { if (responsesUpdater.decrementAndGet(this) == 0) signal(); @@ -65,9 +65,4 @@ protected int ackCount() { return blockFor() - responses; } - - public boolean isLatencyForSnitch() - { - return false; - } } diff --git a/src/java/org/apache/cassandra/service/pager/PagingState.java b/src/java/org/apache/cassandra/service/pager/PagingState.java index f036f96f1dd6..8df2366d14fb 100644 --- a/src/java/org/apache/cassandra/service/pager/PagingState.java +++ b/src/java/org/apache/cassandra/service/pager/PagingState.java @@ -21,10 +21,11 @@ import java.nio.ByteBuffer; import java.util.*; -import org.apache.cassandra.schema.TableMetadata; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.Ints; + import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.CompactTables; -import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.db.marshal.AbstractType; 
import org.apache.cassandra.db.marshal.BytesType; import org.apache.cassandra.db.marshal.CompositeType; @@ -34,10 +35,17 @@ import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.io.util.DataOutputBufferFixed; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.transport.ProtocolVersion; +import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.transport.ProtocolException; -import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.transport.ProtocolVersion; + +import static org.apache.cassandra.db.TypeSizes.sizeof; +import static org.apache.cassandra.db.TypeSizes.sizeofUnsignedVInt; +import static org.apache.cassandra.utils.ByteBufferUtil.*; +import static org.apache.cassandra.utils.vint.VIntCoding.computeUnsignedVIntSize; +import static org.apache.cassandra.utils.vint.VIntCoding.getUnsignedVInt; +@SuppressWarnings("WeakerAccess") public class PagingState { public final ByteBuffer partitionKey; // Can be null for single partition queries. 
@@ -53,92 +61,228 @@ public PagingState(ByteBuffer partitionKey, RowMark rowMark, int remaining, int this.remainingInPartition = remainingInPartition; } - public static PagingState deserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) + public ByteBuffer serialize(ProtocolVersion protocolVersion) { - if (bytes == null) - return null; - - try (DataInputBuffer in = new DataInputBuffer(bytes, true)) + assert rowMark == null || protocolVersion == rowMark.protocolVersion; + try { - ByteBuffer pk; - RowMark mark; - int remaining, remainingInPartition; - if (protocolVersion.isSmallerOrEqualTo(ProtocolVersion.V3)) - { - pk = ByteBufferUtil.readWithShortLength(in); - mark = new RowMark(ByteBufferUtil.readWithShortLength(in), protocolVersion); - remaining = in.readInt(); - // Note that while 'in.available()' is theoretically an estimate of how many bytes are available - // without blocking, we know that since we're reading a ByteBuffer it will be exactly how many - // bytes remain to be read. And the reason we want to condition this is for backward compatility - // as we used to not set this. - remainingInPartition = in.available() > 0 ? in.readInt() : Integer.MAX_VALUE; - } - else - { - pk = ByteBufferUtil.readWithVIntLength(in); - mark = new RowMark(ByteBufferUtil.readWithVIntLength(in), protocolVersion); - remaining = (int)in.readUnsignedVInt(); - remainingInPartition = (int)in.readUnsignedVInt(); - } - return new PagingState(pk.hasRemaining() ? pk : null, - mark.mark.hasRemaining() ? mark : null, - remaining, - remainingInPartition); + return protocolVersion.isGreaterThan(ProtocolVersion.V3) ? 
modernSerialize() : legacySerialize(true); } catch (IOException e) { - throw new ProtocolException("Invalid value for the paging state"); + throw new RuntimeException(e); } } - public ByteBuffer serialize(ProtocolVersion protocolVersion) + public int serializedSize(ProtocolVersion protocolVersion) { assert rowMark == null || protocolVersion == rowMark.protocolVersion; - try (DataOutputBuffer out = new DataOutputBufferFixed(serializedSize(protocolVersion))) + + return protocolVersion.isGreaterThan(ProtocolVersion.V3) ? modernSerializedSize() : legacySerializedSize(true); + } + + /** + * It's possible to receive a V3 paging state on a V4 client session, and vice versa - so we cannot + * blindly rely on the protocol version provided. We must verify first that the buffer indeed contains + * a paging state that adheres to the protocol version provided, or, if not - see if it is in a different + * version, in which case we try the other format. + */ + public static PagingState deserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) + { + if (bytes == null) + return null; + + try { - ByteBuffer pk = partitionKey == null ? ByteBufferUtil.EMPTY_BYTE_BUFFER : partitionKey; - ByteBuffer mark = rowMark == null ? 
ByteBufferUtil.EMPTY_BYTE_BUFFER : rowMark.mark; - if (protocolVersion.isSmallerOrEqualTo(ProtocolVersion.V3)) + /* + * We can't just attempt to deser twice, as we risk to misinterpet short/vint + * lengths and allocate huge byte arrays for readWithVIntLength() or, + * to a lesser extent, readWithShortLength() + */ + + if (protocolVersion.isGreaterThan(ProtocolVersion.V3)) { - ByteBufferUtil.writeWithShortLength(pk, out); - ByteBufferUtil.writeWithShortLength(mark, out); - out.writeInt(remaining); - out.writeInt(remainingInPartition); + if (isModernSerialized(bytes)) return modernDeserialize(bytes, protocolVersion); + if (isLegacySerialized(bytes)) return legacyDeserialize(bytes, ProtocolVersion.V3); } - else + + if (protocolVersion.isSmallerThan(ProtocolVersion.V4)) { - ByteBufferUtil.writeWithVIntLength(pk, out); - ByteBufferUtil.writeWithVIntLength(mark, out); - out.writeUnsignedVInt(remaining); - out.writeUnsignedVInt(remainingInPartition); + if (isLegacySerialized(bytes)) return legacyDeserialize(bytes, protocolVersion); + if (isModernSerialized(bytes)) return modernDeserialize(bytes, ProtocolVersion.V4); } - return out.buffer(); } catch (IOException e) { - throw new RuntimeException(e); + throw new ProtocolException("Invalid value for the paging state"); } + + throw new ProtocolException("Invalid value for the paging state"); } - public int serializedSize(ProtocolVersion protocolVersion) + /* + * Modern serde (> VERSION_3) + */ + + @SuppressWarnings({ "resource", "RedundantSuppression" }) + private ByteBuffer modernSerialize() throws IOException { - assert rowMark == null || protocolVersion == rowMark.protocolVersion; - ByteBuffer pk = partitionKey == null ? ByteBufferUtil.EMPTY_BYTE_BUFFER : partitionKey; - ByteBuffer mark = rowMark == null ? 
ByteBufferUtil.EMPTY_BYTE_BUFFER : rowMark.mark; - if (protocolVersion.isSmallerOrEqualTo(ProtocolVersion.V3)) - { - return ByteBufferUtil.serializedSizeWithShortLength(pk) - + ByteBufferUtil.serializedSizeWithShortLength(mark) - + 8; // remaining & remainingInPartition - } - else + DataOutputBuffer out = new DataOutputBufferFixed(modernSerializedSize()); + writeWithVIntLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey, out); + writeWithVIntLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark, out); + out.writeUnsignedVInt(remaining); + out.writeUnsignedVInt(remainingInPartition); + return out.buffer(false); + } + + private static boolean isModernSerialized(ByteBuffer bytes) + { + int index = bytes.position(); + int limit = bytes.limit(); + + long partitionKeyLen = getUnsignedVInt(bytes, index, limit); + if (partitionKeyLen < 0) + return false; + index += computeUnsignedVIntSize(partitionKeyLen) + partitionKeyLen; + if (index >= limit) + return false; + + long rowMarkerLen = getUnsignedVInt(bytes, index, limit); + if (rowMarkerLen < 0) + return false; + index += computeUnsignedVIntSize(rowMarkerLen) + rowMarkerLen; + if (index >= limit) + return false; + + long remaining = getUnsignedVInt(bytes, index, limit); + if (remaining < 0) + return false; + index += computeUnsignedVIntSize(remaining); + if (index >= limit) + return false; + + long remainingInPartition = getUnsignedVInt(bytes, index, limit); + if (remainingInPartition < 0) + return false; + index += computeUnsignedVIntSize(remainingInPartition); + return index == limit; + } + + @SuppressWarnings({ "resource", "RedundantSuppression" }) + private static PagingState modernDeserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) throws IOException + { + if (protocolVersion.isSmallerThan(ProtocolVersion.V4)) + throw new IllegalArgumentException(); + + DataInputBuffer in = new DataInputBuffer(bytes, false); + + ByteBuffer partitionKey = readWithVIntLength(in); + ByteBuffer rawMark = 
readWithVIntLength(in); + int remaining = Ints.checkedCast(in.readUnsignedVInt()); + int remainingInPartition = Ints.checkedCast(in.readUnsignedVInt()); + + return new PagingState(partitionKey.hasRemaining() ? partitionKey : null, + rawMark.hasRemaining() ? new RowMark(rawMark, protocolVersion) : null, + remaining, + remainingInPartition); + } + + private int modernSerializedSize() + { + return serializedSizeWithVIntLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey) + + serializedSizeWithVIntLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark) + + sizeofUnsignedVInt(remaining) + + sizeofUnsignedVInt(remainingInPartition); + } + + /* + * Legacy serde (< VERSION_4) + * + * There are two versions of legacy PagingState format - one used by 2.1/2.2 and one used by 3.0+. + * The latter includes remainingInPartition count, while the former doesn't. + */ + + @VisibleForTesting + @SuppressWarnings({ "resource", "RedundantSuppression" }) + ByteBuffer legacySerialize(boolean withRemainingInPartition) throws IOException + { + DataOutputBuffer out = new DataOutputBufferFixed(legacySerializedSize(withRemainingInPartition)); + writeWithShortLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey, out); + writeWithShortLength(null == rowMark ? 
EMPTY_BYTE_BUFFER : rowMark.mark, out); + out.writeInt(remaining); + if (withRemainingInPartition) + out.writeInt(remainingInPartition); + return out.buffer(false); + } + + private static boolean isLegacySerialized(ByteBuffer bytes) + { + int index = bytes.position(); + int limit = bytes.limit(); + + if (limit - index < 2) + return false; + short partitionKeyLen = bytes.getShort(index); + if (partitionKeyLen < 0) + return false; + index += 2 + partitionKeyLen; + + if (limit - index < 2) + return false; + short rowMarkerLen = bytes.getShort(index); + if (rowMarkerLen < 0) + return false; + index += 2 + rowMarkerLen; + + if (limit - index < 4) + return false; + int remaining = bytes.getInt(index); + if (remaining < 0) + return false; + index += 4; + + // V3 encoded by 2.1/2.2 - sans remainingInPartition + if (index == limit) + return true; + + if (limit - index == 4) { - return ByteBufferUtil.serializedSizeWithVIntLength(pk) - + ByteBufferUtil.serializedSizeWithVIntLength(mark) - + TypeSizes.sizeofUnsignedVInt(remaining) - + TypeSizes.sizeofUnsignedVInt(remainingInPartition); + int remainingInPartition = bytes.getInt(index); + return remainingInPartition >= 0; // the value must make sense } + return false; + } + + @SuppressWarnings({ "resource", "RedundantSuppression" }) + private static PagingState legacyDeserialize(ByteBuffer bytes, ProtocolVersion protocolVersion) throws IOException + { + if (protocolVersion.isGreaterThan(ProtocolVersion.V3)) + throw new IllegalArgumentException(); + + DataInputBuffer in = new DataInputBuffer(bytes, false); + + ByteBuffer partitionKey = readWithShortLength(in); + ByteBuffer rawMark = readWithShortLength(in); + int remaining = in.readInt(); + /* + * 2.1/2.2 implementations of V3 protocol did not write remainingInPartition, but C* 3.0+ does, so we need + * to handle both variants of V3 serialization for compatibility. + */ + int remainingInPartition = in.available() > 0 ? 
in.readInt() : Integer.MAX_VALUE; + + return new PagingState(partitionKey.hasRemaining() ? partitionKey : null, + rawMark.hasRemaining() ? new RowMark(rawMark, protocolVersion) : null, + remaining, + remainingInPartition); + } + + @VisibleForTesting + int legacySerializedSize(boolean withRemainingInPartition) + { + return serializedSizeWithShortLength(null == partitionKey ? EMPTY_BYTE_BUFFER : partitionKey) + + serializedSizeWithShortLength(null == rowMark ? EMPTY_BYTE_BUFFER : rowMark.mark) + + sizeof(remaining) + + (withRemainingInPartition ? sizeof(remainingInPartition) : 0); } @Override @@ -163,7 +307,7 @@ public final boolean equals(Object o) public String toString() { return String.format("PagingState(key=%s, cellname=%s, remaining=%d, remainingInPartition=%d", - partitionKey != null ? ByteBufferUtil.bytesToHex(partitionKey) : null, + partitionKey != null ? bytesToHex(partitionKey) : null, rowMark, remaining, remainingInPartition); @@ -217,7 +361,7 @@ public static RowMark create(TableMetadata metadata, Row row, ProtocolVersion pr // If the last returned row has no cell, this means in 2.1/2.2 terms that we stopped on the row // marker. Note that this shouldn't happen if the table is COMPACT. assert !metadata.isCompactTable(); - mark = encodeCellName(metadata, row.clustering(), ByteBufferUtil.EMPTY_BYTE_BUFFER, null); + mark = encodeCellName(metadata, row.clustering(), EMPTY_BYTE_BUFFER, null); } else { @@ -268,7 +412,7 @@ private static ByteBuffer encodeCellName(TableMetadata metadata, Clustering clus { if (isStatic) { - values[i] = ByteBufferUtil.EMPTY_BYTE_BUFFER; + values[i] = EMPTY_BYTE_BUFFER; continue; } @@ -336,7 +480,7 @@ public final boolean equals(Object o) @Override public String toString() { - return mark == null ? "null" : ByteBufferUtil.bytesToHex(mark); + return mark == null ? 
"null" : bytesToHex(mark); } } } diff --git a/src/java/org/apache/cassandra/service/paxos/AbstractPaxosCallback.java b/src/java/org/apache/cassandra/service/paxos/AbstractPaxosCallback.java index 90bfc5d9d759..ab24f50efb8c 100644 --- a/src/java/org/apache/cassandra/service/paxos/AbstractPaxosCallback.java +++ b/src/java/org/apache/cassandra/service/paxos/AbstractPaxosCallback.java @@ -1,4 +1,3 @@ -package org.apache.cassandra.service.paxos; /* * * Licensed to the Apache Software Foundation (ASF) under one @@ -19,18 +18,19 @@ * under the License. * */ - +package org.apache.cassandra.service.paxos; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.db.WriteType; import org.apache.cassandra.exceptions.WriteTimeoutException; -import org.apache.cassandra.net.IAsyncCallback; +import org.apache.cassandra.net.RequestCallback; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; -public abstract class AbstractPaxosCallback implements IAsyncCallback +public abstract class AbstractPaxosCallback implements RequestCallback { protected final CountDownLatch latch; protected final int targets; @@ -45,11 +45,6 @@ public AbstractPaxosCallback(int targets, ConsistencyLevel consistency, long que this.queryStartNanoTime = queryStartNanoTime; } - public boolean isLatencyForSnitch() - { - return false; - } - public int getResponseCount() { return (int) (targets - latch.getCount()); @@ -59,8 +54,8 @@ public void await() throws WriteTimeoutException { try { - long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getWriteRpcTimeout()) - (System.nanoTime() - queryStartNanoTime); - if (!latch.await(timeout, TimeUnit.NANOSECONDS)) + long timeout = DatabaseDescriptor.getWriteRpcTimeout(NANOSECONDS) - (System.nanoTime() - queryStartNanoTime); + if (!latch.await(timeout, NANOSECONDS)) throw new 
WriteTimeoutException(WriteType.CAS, consistency, getResponseCount(), targets); } catch (InterruptedException ex) diff --git a/src/java/org/apache/cassandra/service/paxos/CommitVerbHandler.java b/src/java/org/apache/cassandra/service/paxos/CommitVerbHandler.java index a702a4dfd750..12466ddc691e 100644 --- a/src/java/org/apache/cassandra/service/paxos/CommitVerbHandler.java +++ b/src/java/org/apache/cassandra/service/paxos/CommitVerbHandler.java @@ -20,19 +20,20 @@ */ package org.apache.cassandra.service.paxos; -import org.apache.cassandra.db.WriteResponse; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.tracing.Tracing; public class CommitVerbHandler implements IVerbHandler { - public void doVerb(MessageIn message, int id) + public static final CommitVerbHandler instance = new CommitVerbHandler(); + + public void doVerb(Message message) { PaxosState.commit(message.payload); - Tracing.trace("Enqueuing acknowledge to {}", message.from); - MessagingService.instance().sendReply(WriteResponse.createMessage(), id, message.from); + Tracing.trace("Enqueuing acknowledge to {}", message.from()); + MessagingService.instance().send(message.emptyResponse(), message.from()); } } diff --git a/src/java/org/apache/cassandra/service/paxos/PrepareCallback.java b/src/java/org/apache/cassandra/service/paxos/PrepareCallback.java index ed70e964c03e..26890a9b102a 100644 --- a/src/java/org/apache/cassandra/service/paxos/PrepareCallback.java +++ b/src/java/org/apache/cassandra/service/paxos/PrepareCallback.java @@ -36,7 +36,7 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.db.SystemKeyspace; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.utils.UUIDGen; public class PrepareCallback extends AbstractPaxosCallback @@ -53,16 +53,16 @@ public class 
PrepareCallback extends AbstractPaxosCallback public PrepareCallback(DecoratedKey key, TableMetadata metadata, int targets, ConsistencyLevel consistency, long queryStartNanoTime) { super(targets, consistency, queryStartNanoTime); - // need to inject the right key in the empty commit so comparing with empty commits in the reply works as expected + // need to inject the right key in the empty commit so comparing with empty commits in the response works as expected mostRecentCommit = Commit.emptyCommit(key, metadata); mostRecentInProgressCommit = Commit.emptyCommit(key, metadata); mostRecentInProgressCommitWithUpdate = Commit.emptyCommit(key, metadata); } - public synchronized void response(MessageIn message) + public synchronized void onResponse(Message message) { PrepareResponse response = message.payload; - logger.trace("Prepare response {} from {}", response, message.from); + logger.trace("Prepare response {} from {}", response, message.from()); // In case of clock skew, another node could be proposing with ballot that are quite a bit // older than our own. In that case, we record the more recent commit we've received to make @@ -78,7 +78,7 @@ public synchronized void response(MessageIn message) return; } - commitsByReplica.put(message.from, response.mostRecentCommit); + commitsByReplica.put(message.from(), response.mostRecentCommit); if (response.mostRecentCommit.isAfter(mostRecentCommit)) mostRecentCommit = response.mostRecentCommit; diff --git a/src/java/org/apache/cassandra/service/paxos/PrepareVerbHandler.java b/src/java/org/apache/cassandra/service/paxos/PrepareVerbHandler.java index 2750b7611207..157630f277af 100644 --- a/src/java/org/apache/cassandra/service/paxos/PrepareVerbHandler.java +++ b/src/java/org/apache/cassandra/service/paxos/PrepareVerbHandler.java @@ -19,23 +19,22 @@ * under the License. 
* */ - - import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; public class PrepareVerbHandler implements IVerbHandler { + public static PrepareVerbHandler instance = new PrepareVerbHandler(); + public static PrepareResponse doPrepare(Commit toPrepare) { return PaxosState.prepare(toPrepare); } - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - MessageOut reply = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE, doPrepare(message.payload), PrepareResponse.serializer); - MessagingService.instance().sendReply(reply, id, message.from); + Message reply = message.responseWith(doPrepare(message.payload)); + MessagingService.instance().send(reply, message.from()); } } diff --git a/src/java/org/apache/cassandra/service/paxos/ProposeCallback.java b/src/java/org/apache/cassandra/service/paxos/ProposeCallback.java index c9cb1f0604b3..7e755a03d8a6 100644 --- a/src/java/org/apache/cassandra/service/paxos/ProposeCallback.java +++ b/src/java/org/apache/cassandra/service/paxos/ProposeCallback.java @@ -27,7 +27,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; /** * ProposeCallback has two modes of operation, controlled by the failFast parameter. @@ -35,7 +35,7 @@ * In failFast mode, we will return a failure as soon as a majority of nodes reject * the proposal. This is used when replaying a proposal from an earlier leader. * - * Otherwise, we wait for either all replicas to reply or until we achieve + * Otherwise, we wait for either all replicas to respond or until we achieve * the desired quorum. We continue to wait for all replicas even after we know we cannot succeed * because we need to know if no node at all have accepted or if at least one has. 
* In the former case, a proposer is guaranteed no-one will @@ -57,9 +57,9 @@ public ProposeCallback(int totalTargets, int requiredTargets, boolean failFast, this.failFast = failFast; } - public void response(MessageIn msg) + public void onResponse(Message msg) { - logger.trace("Propose response {} from {}", msg.payload, msg.from); + logger.trace("Propose response {} from {}", msg.payload, msg.from()); if (msg.payload) accepts.incrementAndGet(); diff --git a/src/java/org/apache/cassandra/service/paxos/ProposeVerbHandler.java b/src/java/org/apache/cassandra/service/paxos/ProposeVerbHandler.java index 81c90174e054..5a20b674580f 100644 --- a/src/java/org/apache/cassandra/service/paxos/ProposeVerbHandler.java +++ b/src/java/org/apache/cassandra/service/paxos/ProposeVerbHandler.java @@ -19,24 +19,22 @@ * under the License. * */ - - import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.BooleanSerializer; public class ProposeVerbHandler implements IVerbHandler { + public static final ProposeVerbHandler instance = new ProposeVerbHandler(); + public static Boolean doPropose(Commit proposal) { return PaxosState.propose(proposal); } - public void doVerb(MessageIn message, int id) + public void doVerb(Message message) { - MessageOut reply = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE, doPropose(message.payload), BooleanSerializer.serializer); - MessagingService.instance().sendReply(reply, id, message.from); + Message reply = message.responseWith(doPropose(message.payload)); + MessagingService.instance().send(reply, message.from()); } } diff --git a/src/java/org/apache/cassandra/service/reads/AbstractReadExecutor.java b/src/java/org/apache/cassandra/service/reads/AbstractReadExecutor.java index 174ed7bff6f7..27e928072577 100644 --- 
a/src/java/org/apache/cassandra/service/reads/AbstractReadExecutor.java +++ b/src/java/org/apache/cassandra/service/reads/AbstractReadExecutor.java @@ -17,8 +17,6 @@ */ package org.apache.cassandra.service.reads; -import java.util.concurrent.TimeUnit; - import com.google.common.base.Preconditions; import com.google.common.base.Predicates; @@ -43,7 +41,7 @@ import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.ReplicaCollection; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.reads.repair.ReadRepair; import org.apache.cassandra.service.StorageProxy.LocalReadRunnable; @@ -51,6 +49,7 @@ import org.apache.cassandra.tracing.Tracing; import static com.google.common.collect.Iterables.all; +import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * Sends a read request to the replicas needed to satisfy a given ConsistencyLevel. @@ -95,7 +94,7 @@ public abstract class AbstractReadExecutor // we stop being compatible with pre-3.0 nodes. int digestVersion = MessagingService.current_version; for (Replica replica : replicaPlan.contacts()) - digestVersion = Math.min(digestVersion, MessagingService.instance().getVersion(replica.endpoint())); + digestVersion = Math.min(digestVersion, MessagingService.instance().versions.get(replica.endpoint())); command.setDigestVersion(digestVersion); } @@ -132,6 +131,7 @@ protected void makeDigestRequests(Iterable replicas) private void makeRequests(ReadCommand readCommand, Iterable replicas) { boolean hasLocalEndpoint = false; + Message message = null; for (Replica replica: replicas) { @@ -146,8 +146,11 @@ private void makeRequests(ReadCommand readCommand, Iterable replicas) if (traceState != null) traceState.trace("reading {} from {}", readCommand.isDigestQuery() ? 
"digest" : "data", endpoint); - MessageOut message = readCommand.createMessage(); - MessagingService.instance().sendRRWithFailure(message, endpoint, handler); + + if (null == message) + message = readCommand.createMessage(false); + + MessagingService.instance().sendWithCallback(message, endpoint, handler); } // We delay the local (potentially blocking) read till the end to avoid stalling remote requests. @@ -213,10 +216,10 @@ public static AbstractReadExecutor getReadExecutor(SinglePartitionReadCommand co boolean shouldSpeculateAndMaybeWait() { // no latency information, or we're overloaded - if (cfs.sampleReadLatencyNanos > TimeUnit.MILLISECONDS.toNanos(command.getTimeout())) + if (cfs.sampleReadLatencyNanos > command.getTimeout(NANOSECONDS)) return false; - return !handler.await(cfs.sampleReadLatencyNanos, TimeUnit.NANOSECONDS); + return !handler.await(cfs.sampleReadLatencyNanos, NANOSECONDS); } ReplicaPlan.ForTokenRead replicaPlan() @@ -261,7 +264,7 @@ public SpeculatingReadExecutor(ColumnFamilyStore cfs, { // We're hitting additional targets for read repair (??). Since our "extra" replica is the least- // preferred by the snitch, we do an extra data read to start with against a replica more - // likely to reply; better to let RR fail than the entire query. + // likely to respond; better to let RR fail than the entire query. super(cfs, command, replicaPlan, replicaPlan.blockFor() < replicaPlan.contacts().size() ? 
2 : 1, queryStartNanoTime); } @@ -308,7 +311,7 @@ public void maybeTryAdditionalReplicas() if (traceState != null) traceState.trace("speculating read retry on {}", extraReplica); logger.trace("speculating read retry on {}", extraReplica); - MessagingService.instance().sendRRWithFailure(retryCommand.createMessage(), extraReplica.endpoint(), handler); + MessagingService.instance().sendWithCallback(retryCommand.createMessage(false), extraReplica.endpoint(), handler); } } diff --git a/src/java/org/apache/cassandra/service/reads/DataResolver.java b/src/java/org/apache/cassandra/service/reads/DataResolver.java index 03f718f25a41..45bf9186004f 100644 --- a/src/java/org/apache/cassandra/service/reads/DataResolver.java +++ b/src/java/org/apache/cassandra/service/reads/DataResolver.java @@ -43,14 +43,13 @@ import org.apache.cassandra.db.transform.Transformation; import org.apache.cassandra.locator.Endpoints; import org.apache.cassandra.locator.ReplicaPlan; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.service.reads.repair.ReadRepair; import org.apache.cassandra.service.reads.repair.RepairedDataTracker; import org.apache.cassandra.service.reads.repair.RepairedDataVerifier; import static com.google.common.collect.Iterables.*; -import static org.apache.cassandra.db.partitions.UnfilteredPartitionIterators.MergeListener; public class DataResolver, P extends ReplicaPlan.ForRead> extends ResponseResolver { @@ -80,10 +79,10 @@ public PartitionIterator resolve() { // We could get more responses while this method runs, which is ok (we're happy to ignore any response not here // at the beginning of this method), so grab the response count once and use that through the method. 
- Collection> messages = responses.snapshot(); + Collection> messages = responses.snapshot(); assert !any(messages, msg -> msg.payload.isDigestResponse()); - E replicas = replicaPlan().candidates().select(transform(messages, msg -> msg.from), false); + E replicas = replicaPlan().candidates().select(transform(messages, msg -> msg.from()), false); List iters = new ArrayList<>( Collections2.transform(messages, msg -> msg.payload.makeIterator(command))); assert replicas.size() == iters.size(); @@ -95,9 +94,9 @@ public PartitionIterator resolve() if (repairedDataTracker != null) { messages.forEach(msg -> { - if (msg.payload.mayIncludeRepairedDigest() && replicas.byEndpoint().get(msg.from).isFull()) + if (msg.payload.mayIncludeRepairedDigest() && replicas.byEndpoint().get(msg.from()).isFull()) { - repairedDataTracker.recordDigest(msg.from, + repairedDataTracker.recordDigest(msg.from(), msg.payload.repairedDataDigest(), msg.payload.isRepairedDigestConclusive()); } @@ -157,7 +156,7 @@ private UnfilteredPartitionIterator mergeWithShortReadProtection(List m.from + " => " + m.payload.toDebugString(command, partitionKey))); + return Joiner.on(",\n").join(transform(getMessages().snapshot(), m -> m.from() + " => " + m.payload.toDebugString(command, partitionKey))); } private UnfilteredPartitionIterators.MergeListener wrapMergeListener(UnfilteredPartitionIterators.MergeListener partitionListener, diff --git a/src/java/org/apache/cassandra/service/reads/DigestResolver.java b/src/java/org/apache/cassandra/service/reads/DigestResolver.java index 899baf9830f0..cf7ec315ca97 100644 --- a/src/java/org/apache/cassandra/service/reads/DigestResolver.java +++ b/src/java/org/apache/cassandra/service/reads/DigestResolver.java @@ -33,16 +33,15 @@ import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.ReplicaPlan; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import 
org.apache.cassandra.service.reads.repair.NoopReadRepair; -import org.apache.cassandra.service.reads.repair.ReadRepair; import org.apache.cassandra.utils.ByteBufferUtil; import static com.google.common.collect.Iterables.any; public class DigestResolver, P extends ReplicaPlan.ForRead> extends ResponseResolver { - private volatile MessageIn dataResponse; + private volatile Message dataResponse; public DigestResolver(ReadCommand command, ReplicaPlan.Shared replicaPlan, long queryStartNanoTime) { @@ -52,10 +51,10 @@ public DigestResolver(ReadCommand command, ReplicaPlan.Shared replicaPlan, } @Override - public void preprocess(MessageIn message) + public void preprocess(Message message) { super.preprocess(message); - Replica replica = replicaPlan().getReplicaFor(message.from); + Replica replica = replicaPlan().getReplicaFor(message.from()); if (dataResponse == null && !message.payload.isDigestResponse() && replica.isFull()) dataResponse = message; } @@ -66,18 +65,18 @@ public boolean hasTransientResponse() return hasTransientResponse(responses.snapshot()); } - private boolean hasTransientResponse(Collection> responses) + private boolean hasTransientResponse(Collection> responses) { return any(responses, msg -> !msg.payload.isDigestResponse() - && replicaPlan().getReplicaFor(msg.from).isTransient()); + && replicaPlan().getReplicaFor(msg.from()).isTransient()); } public PartitionIterator getData() { assert isDataPresent(); - Collection> responses = this.responses.snapshot(); + Collection> responses = this.responses.snapshot(); if (!hasTransientResponse(responses)) { @@ -92,9 +91,9 @@ public PartitionIterator getData() dataResolver.preprocess(dataResponse); // Reconcile with transient replicas - for (MessageIn response : responses) + for (Message response : responses) { - Replica replica = replicaPlan().getReplicaFor(response.from); + Replica replica = replicaPlan().getReplicaFor(response.from()); if (replica.isTransient()) dataResolver.preprocess(response); } @@ -109,9 
+108,14 @@ public boolean responsesMatch() // validate digests against each other; return false immediately on mismatch. ByteBuffer digest = null; - for (MessageIn message : responses.snapshot()) + Collection> snapshot = responses.snapshot(); + if (snapshot.size() <= 1) + return true; + + // TODO: should also not calculate if only one full node + for (Message message : snapshot) { - if (replicaPlan().getReplicaFor(message.from).isTransient()) + if (replicaPlan().getReplicaFor(message.from()).isTransient()) continue; ByteBuffer newDigest = message.payload.digest(command); @@ -138,10 +142,10 @@ public DigestResolverDebugResult[] getDigestsByEndpoint() DigestResolverDebugResult[] ret = new DigestResolverDebugResult[responses.size()]; for (int i = 0; i < responses.size(); i++) { - MessageIn message = responses.get(i); + Message message = responses.get(i); ReadResponse response = message.payload; String digestHex = ByteBufferUtil.bytesToHex(response.digest(command)); - ret[i] = new DigestResolverDebugResult(message.from, digestHex, message.payload.isDigestResponse()); + ret[i] = new DigestResolverDebugResult(message.from(), digestHex, message.payload.isDigestResponse()); } return ret; } diff --git a/src/java/org/apache/cassandra/service/reads/ReadCallback.java b/src/java/org/apache/cassandra/service/reads/ReadCallback.java index 7a2385c6ca4b..2968dbce09b5 100644 --- a/src/java/org/apache/cassandra/service/reads/ReadCallback.java +++ b/src/java/org/apache/cassandra/service/reads/ReadCallback.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.service.reads; -import java.util.Collections; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -36,14 +35,15 @@ import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.locator.Endpoints; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IAsyncCallbackWithFailure; -import 
org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.tracing.Tracing; -import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.concurrent.SimpleCondition; -public class ReadCallback, P extends ReplicaPlan.ForRead> implements IAsyncCallbackWithFailure +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class ReadCallback, P extends ReplicaPlan.ForRead> implements RequestCallback { protected static final Logger logger = LoggerFactory.getLogger( ReadCallback.class ); @@ -98,7 +98,7 @@ public boolean await(long timePastStart, TimeUnit unit) public void awaitResults() throws ReadFailureException, ReadTimeoutException { - boolean signaled = await(command.getTimeout(), TimeUnit.MILLISECONDS); + boolean signaled = await(command.getTimeout(MILLISECONDS), TimeUnit.MILLISECONDS); boolean failed = failures > 0 && blockFor + failures > replicaPlan().contacts().size(); if (signaled && !failed) return; @@ -125,10 +125,10 @@ public int blockFor() return blockFor; } - public void response(MessageIn message) + public void onResponse(Message message) { resolver.preprocess(message); - int n = waitingFor(message.from) + int n = waitingFor(message.from()) ? recievedUpdater.incrementAndGet(this) : received; @@ -146,15 +146,14 @@ private boolean waitingFor(InetAddressAndPort from) public void response(ReadResponse result) { - MessageIn message = MessageIn.create(FBUtilities.getBroadcastAddressAndPort(), - result, - Collections.emptyMap(), - MessagingService.Verb.INTERNAL_RESPONSE, - MessagingService.current_version); - response(message); + Verb kind = command.isRangeRequest() ? 
Verb.RANGE_RSP : Verb.READ_RSP; + Message message = Message.internalResponse(kind, result); + onResponse(message); } - public boolean isLatencyForSnitch() + + @Override + public boolean trackLatencyForSnitch() { return true; } @@ -171,4 +170,10 @@ public void onFailure(InetAddressAndPort from, RequestFailureReason failureReaso if (blockFor + n > replicaPlan().contacts().size()) condition.signalAll(); } + + @Override + public boolean invokeOnFailure() + { + return true; + } } diff --git a/src/java/org/apache/cassandra/service/reads/ResponseResolver.java b/src/java/org/apache/cassandra/service/reads/ResponseResolver.java index aaead8439fa6..8e15c1a9a018 100644 --- a/src/java/org/apache/cassandra/service/reads/ResponseResolver.java +++ b/src/java/org/apache/cassandra/service/reads/ResponseResolver.java @@ -24,7 +24,7 @@ import org.apache.cassandra.db.ReadResponse; import org.apache.cassandra.locator.Endpoints; import org.apache.cassandra.locator.ReplicaPlan; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.utils.concurrent.Accumulator; public abstract class ResponseResolver, P extends ReplicaPlan.ForRead> @@ -35,7 +35,7 @@ public abstract class ResponseResolver, P extends Replica protected final ReplicaPlan.Shared replicaPlan; // Accumulator gives us non-blocking thread-safety with optimal algorithmic constraints - protected final Accumulator> responses; + protected final Accumulator> responses; protected final long queryStartNanoTime; public ResponseResolver(ReadCommand command, ReplicaPlan.Shared replicaPlan, long queryStartNanoTime) @@ -53,9 +53,9 @@ protected P replicaPlan() public abstract boolean isDataPresent(); - public void preprocess(MessageIn message) + public void preprocess(Message message) { - if (replicaPlan().getReplicaFor(message.from).isTransient() && + if (replicaPlan().getReplicaFor(message.from()).isTransient() && message.payload.isDigestResponse()) throw new 
IllegalArgumentException("Digest response received from transient replica"); @@ -71,7 +71,7 @@ public void preprocess(MessageIn message) } } - public Accumulator> getMessages() + public Accumulator> getMessages() { return responses; } diff --git a/src/java/org/apache/cassandra/service/reads/ShortReadPartitionsProtection.java b/src/java/org/apache/cassandra/service/reads/ShortReadPartitionsProtection.java index 7b7c4d339fc1..8735ceddaca8 100644 --- a/src/java/org/apache/cassandra/service/reads/ShortReadPartitionsProtection.java +++ b/src/java/org/apache/cassandra/service/reads/ShortReadPartitionsProtection.java @@ -189,7 +189,7 @@ UnfilteredPartitionIterator executeReadCommand(ReadCommand cmd, ReplicaPlan.Shar { if (source.isTransient()) cmd = cmd.copyAsTransientQuery(source); - MessagingService.instance().sendRRWithFailure(cmd.createMessage(), source.endpoint(), handler); + MessagingService.instance().sendWithCallback(cmd.createMessage(false), source.endpoint(), handler); } // We don't call handler.get() because we want to preserve tombstones since we're still in the middle of merging node results. 
diff --git a/src/java/org/apache/cassandra/service/reads/repair/AbstractReadRepair.java b/src/java/org/apache/cassandra/service/reads/repair/AbstractReadRepair.java index 761ffb0233a6..79c124b5ef4d 100644 --- a/src/java/org/apache/cassandra/service/reads/repair/AbstractReadRepair.java +++ b/src/java/org/apache/cassandra/service/reads/repair/AbstractReadRepair.java @@ -18,7 +18,6 @@ package org.apache.cassandra.service.reads.repair; -import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import com.google.common.base.Preconditions; @@ -40,15 +39,17 @@ import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.ReplicaPlan; import org.apache.cassandra.metrics.ReadRepairMetrics; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; import org.apache.cassandra.service.StorageProxy; import org.apache.cassandra.service.reads.DataResolver; import org.apache.cassandra.service.reads.DigestResolver; import org.apache.cassandra.service.reads.ReadCallback; import org.apache.cassandra.tracing.Tracing; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + public abstract class AbstractReadRepair, P extends ReplicaPlan.ForRead> implements ReadRepair { @@ -113,12 +114,9 @@ void sendReadCommand(Replica to, ReadCallback readCallback, boolean speculative) else type = to.isFull() ? 
"full" : "transient"; Tracing.trace("Enqueuing {} data read to {}", type, to); } - MessageOut message = command.createMessage(); // if enabled, request additional info about repaired data from any full replicas - if (command.isTrackingRepairedStatus() && to.isFull()) - message = message.withParameter(ParameterType.TRACK_REPAIRED_DATA, MessagingService.ONE_BYTE); - - MessagingService.instance().sendRRWithFailure(message, to.endpoint(), readCallback); + Message message = command.createMessage(command.isTrackingRepairedStatus() && to.isFull()); + MessagingService.instance().sendWithCallback(message, to.endpoint(), readCallback); } abstract Meter getRepairMeter(); @@ -160,7 +158,7 @@ private boolean shouldSpeculate() ConsistencyLevel speculativeCL = consistency.isDatacenterLocal() ? ConsistencyLevel.LOCAL_QUORUM : ConsistencyLevel.QUORUM; return consistency != ConsistencyLevel.EACH_QUORUM && consistency.satisfies(speculativeCL, cfs.keyspace) - && cfs.sampleReadLatencyNanos <= TimeUnit.MILLISECONDS.toNanos(command.getTimeout()); + && cfs.sampleReadLatencyNanos <= command.getTimeout(NANOSECONDS); } public void maybeSendAdditionalReads() @@ -171,7 +169,7 @@ public void maybeSendAdditionalReads() if (repair == null) return; - if (shouldSpeculate() && !repair.readCallback.await(cfs.sampleReadLatencyNanos, TimeUnit.NANOSECONDS)) + if (shouldSpeculate() && !repair.readCallback.await(cfs.sampleReadLatencyNanos, NANOSECONDS)) { Replica uncontacted = replicaPlan().firstUncontactedCandidate(Predicates.alwaysTrue()); if (uncontacted == null) diff --git a/src/java/org/apache/cassandra/service/reads/repair/BlockingPartitionRepair.java b/src/java/org/apache/cassandra/service/reads/repair/BlockingPartitionRepair.java index 624c78f310c7..220ada5a5ab7 100644 --- a/src/java/org/apache/cassandra/service/reads/repair/BlockingPartitionRepair.java +++ b/src/java/org/apache/cassandra/service/reads/repair/BlockingPartitionRepair.java @@ -45,15 +45,16 @@ import 
org.apache.cassandra.locator.Replicas; import org.apache.cassandra.locator.InOurDcTester; import org.apache.cassandra.metrics.ReadRepairMetrics; -import org.apache.cassandra.net.IAsyncCallback; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.RequestCallback; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.tracing.Tracing; +import static org.apache.cassandra.net.Verb.*; + public class BlockingPartitionRepair, P extends ReplicaPlan.ForRead> - extends AbstractFuture implements IAsyncCallback + extends AbstractFuture implements RequestCallback { private final DecoratedKey key; private final P replicaPlan; @@ -111,15 +112,9 @@ void ack(InetAddressAndPort from) } @Override - public void response(MessageIn msg) - { - ack(msg.from); - } - - @Override - public boolean isLatencyForSnitch() + public void onResponse(Message msg) { - return false; + ack(msg.from()); } private static PartitionUpdate extractUpdate(Mutation mutation) @@ -138,9 +133,9 @@ private PartitionUpdate mergeUnackedUpdates() } @VisibleForTesting - protected void sendRR(MessageOut message, InetAddressAndPort endpoint) + protected void sendRR(Message message, InetAddressAndPort endpoint) { - MessagingService.instance().sendRR(message, endpoint, this); + MessagingService.instance().sendWithCallback(message, endpoint, this); } public void sendInitialRepairs() @@ -157,7 +152,7 @@ public void sendInitialRepairs() Tracing.trace("Sending read-repair-mutation to {}", destination); // use a separate verb here to avoid writing hints on timeouts - sendRR(mutation.createMessage(MessagingService.Verb.READ_REPAIR), destination.endpoint()); + sendRR(Message.out(READ_REPAIR_REQ, mutation), destination.endpoint()); ColumnFamilyStore.metricsFor(tableId).readRepairRequests.mark(); if (!shouldBlockOn.test(destination.endpoint())) @@ -214,7 
+209,7 @@ public void maybeSendAdditionalWrites(long timeout, TimeUnit timeoutUnit) for (Replica replica : newCandidates) { - int versionIdx = msgVersionIdx(MessagingService.instance().getVersion(replica.endpoint())); + int versionIdx = msgVersionIdx(MessagingService.instance().versions.get(replica.endpoint())); Mutation mutation = versionedMutations[versionIdx]; @@ -232,7 +227,7 @@ public void maybeSendAdditionalWrites(long timeout, TimeUnit timeoutUnit) } Tracing.trace("Sending speculative read-repair-mutation to {}", replica); - sendRR(mutation.createMessage(MessagingService.Verb.READ_REPAIR), replica.endpoint()); + sendRR(Message.out(READ_REPAIR_REQ, mutation), replica.endpoint()); ReadRepairDiagnostics.speculatedWrite(this, replica.endpoint(), mutation); } } diff --git a/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepair.java b/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepair.java index 8016b8944a41..ef624d69e144 100644 --- a/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepair.java +++ b/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepair.java @@ -39,6 +39,9 @@ import org.apache.cassandra.metrics.ReadRepairMetrics; import org.apache.cassandra.tracing.Tracing; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + /** * 'Classic' read repair. Doesn't allow the client read to return until * updates have been written to nodes needing correction. 
Breaks write @@ -84,7 +87,7 @@ public void awaitWrites() boolean timedOut = false; for (BlockingPartitionRepair repair: repairs) { - if (!repair.awaitRepairs(DatabaseDescriptor.getWriteRpcTimeout(), TimeUnit.MILLISECONDS)) + if (!repair.awaitRepairs(DatabaseDescriptor.getWriteRpcTimeout(NANOSECONDS), NANOSECONDS)) { timedOut = true; } diff --git a/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepairs.java b/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepairs.java index ceb176569044..3a4978e7cce0 100644 --- a/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepairs.java +++ b/src/java/org/apache/cassandra/service/reads/repair/BlockingReadRepairs.java @@ -54,7 +54,7 @@ public static Mutation createRepairMutation(PartitionUpdate update, ConsistencyL Keyspace keyspace = Keyspace.open(mutation.getKeyspaceName()); TableMetadata metadata = update.metadata(); - int messagingVersion = MessagingService.instance().getVersion(destination); + int messagingVersion = MessagingService.instance().versions.get(destination); int mutationSize = (int) Mutation.serializer.serializedSize(mutation, messagingVersion); int maxMutationSize = DatabaseDescriptor.getMaxMutationSize(); diff --git a/src/java/org/apache/cassandra/streaming/DefaultConnectionFactory.java b/src/java/org/apache/cassandra/streaming/DefaultConnectionFactory.java index b19280377823..5f2163f5410b 100644 --- a/src/java/org/apache/cassandra/streaming/DefaultConnectionFactory.java +++ b/src/java/org/apache/cassandra/streaming/DefaultConnectionFactory.java @@ -19,97 +19,40 @@ package org.apache.cassandra.streaming; import java.io.IOException; -import java.util.concurrent.TimeUnit; - -import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.Uninterruptibles; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; -import 
io.netty.channel.ChannelFuture; -import io.netty.channel.WriteBufferWaterMark; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.net.async.NettyFactory; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; -import org.apache.cassandra.net.async.OutboundConnectionParams; +import io.netty.channel.EventLoop; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.net.ConnectionCategory; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result.StreamingSuccess; +import org.apache.cassandra.net.OutboundConnectionSettings; + +import static org.apache.cassandra.net.OutboundConnectionInitiator.initiateStreaming; public class DefaultConnectionFactory implements StreamConnectionFactory { - private static final Logger logger = LoggerFactory.getLogger(DefaultConnectionFactory.class); - - private static final int DEFAULT_CHANNEL_BUFFER_SIZE = 1 << 22; - - @VisibleForTesting - public static long MAX_WAIT_TIME_NANOS = TimeUnit.SECONDS.toNanos(30); @VisibleForTesting public static int MAX_CONNECT_ATTEMPTS = 3; @Override - public Channel createConnection(OutboundConnectionIdentifier connectionId, int protocolVersion) throws IOException - { - ServerEncryptionOptions encryptionOptions = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); - - if (encryptionOptions.internode_encryption == ServerEncryptionOptions.InternodeEncryption.none) - encryptionOptions = null; - - return createConnection(connectionId, protocolVersion, encryptionOptions); - } - - protected Channel createConnection(OutboundConnectionIdentifier connectionId, int protocolVersion, @Nullable ServerEncryptionOptions encryptionOptions) throws IOException + public Channel createConnection(OutboundConnectionSettings template, int 
messagingVersion) throws IOException { - // this is the amount of data to allow in memory before netty sets the channel writablility flag to false - int channelBufferSize = DEFAULT_CHANNEL_BUFFER_SIZE; - WriteBufferWaterMark waterMark = new WriteBufferWaterMark(channelBufferSize >> 2, channelBufferSize); - - int sendBufferSize = DatabaseDescriptor.getInternodeSendBufferSize() > 0 - ? DatabaseDescriptor.getInternodeSendBufferSize() - : OutboundConnectionParams.DEFAULT_SEND_BUFFER_SIZE; - - int tcpConnectTimeout = DatabaseDescriptor.getInternodeTcpConnectTimeoutInMS(); - int tcpUserTimeout = DatabaseDescriptor.getInternodeTcpUserTimeoutInMS(); + EventLoop eventLoop = MessagingService.instance().socketFactory.outboundStreamingGroup().next(); - OutboundConnectionParams params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .encryptionOptions(encryptionOptions) - .mode(NettyFactory.Mode.STREAMING) - .protocolVersion(protocolVersion) - .sendBufferSize(sendBufferSize) - .tcpConnectTimeoutInMS(tcpConnectTimeout) - .tcpUserTimeoutInMS(tcpUserTimeout) - .waterMark(waterMark) - .build(); - - Bootstrap bootstrap = NettyFactory.instance.createOutboundBootstrap(params); - - int connectionAttemptCount = 0; - long now = System.nanoTime(); - final long end = now + MAX_WAIT_TIME_NANOS; - final Channel channel; + int attempts = 0; while (true) { - ChannelFuture channelFuture = bootstrap.connect(); - channelFuture.awaitUninterruptibly(end - now, TimeUnit.MILLISECONDS); - if (channelFuture.isSuccess()) - { - channel = channelFuture.channel(); - break; - } + Future> result = initiateStreaming(eventLoop, template.withDefaults(ConnectionCategory.STREAMING), messagingVersion); + result.awaitUninterruptibly(); // initiate has its own timeout, so this is "guaranteed" to return relatively promptly + if (result.isSuccess()) + return result.getNow().success().channel; - connectionAttemptCount++; - now = System.nanoTime(); - if (connectionAttemptCount == MAX_CONNECT_ATTEMPTS 
|| end - now <= 0) - throw new IOException("failed to connect to " + connectionId + " for streaming data", channelFuture.cause()); - - long waitms = DatabaseDescriptor.getRpcTimeout() * (long)Math.pow(2, connectionAttemptCount); - logger.warn("Failed attempt {} to connect to {}. Retrying in {} ms.", connectionAttemptCount, connectionId, waitms); - Uninterruptibles.sleepUninterruptibly(waitms, TimeUnit.MILLISECONDS); + if (++attempts == MAX_CONNECT_ATTEMPTS) + throw new IOException("failed to connect to " + template.to + " for streaming data", result.cause()); } - - return channel; } } diff --git a/src/java/org/apache/cassandra/streaming/ReplicationFinishedVerbHandler.java b/src/java/org/apache/cassandra/streaming/ReplicationDoneVerbHandler.java similarity index 67% rename from src/java/org/apache/cassandra/streaming/ReplicationFinishedVerbHandler.java rename to src/java/org/apache/cassandra/streaming/ReplicationDoneVerbHandler.java index ce8a921ad10f..7d73b114181d 100644 --- a/src/java/org/apache/cassandra/streaming/ReplicationFinishedVerbHandler.java +++ b/src/java/org/apache/cassandra/streaming/ReplicationDoneVerbHandler.java @@ -21,21 +21,20 @@ import org.slf4j.LoggerFactory; import org.apache.cassandra.net.IVerbHandler; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageService; -public class ReplicationFinishedVerbHandler implements IVerbHandler +public class ReplicationDoneVerbHandler implements IVerbHandler { - private static final Logger logger = LoggerFactory.getLogger(ReplicationFinishedVerbHandler.class); + public static ReplicationDoneVerbHandler instance = new ReplicationDoneVerbHandler(); - public void doVerb(MessageIn msg, int id) + private static final Logger logger = LoggerFactory.getLogger(ReplicationDoneVerbHandler.class); + + public void doVerb(Message msg) { - 
StorageService.instance.confirmReplication(msg.from); - MessageOut response = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE); - if (logger.isDebugEnabled()) - logger.debug("Replying to {}@{}", id, msg.from); - MessagingService.instance().sendReply(response, id, msg.from); + StorageService.instance.confirmReplication(msg.from()); + logger.debug("Replying to {}@{}", msg.id(), msg.from()); + MessagingService.instance().send(msg.emptyResponse(), msg.from()); } } diff --git a/src/java/org/apache/cassandra/streaming/SessionSummary.java b/src/java/org/apache/cassandra/streaming/SessionSummary.java index cf63a572cdb9..5b168a0841ed 100644 --- a/src/java/org/apache/cassandra/streaming/SessionSummary.java +++ b/src/java/org/apache/cassandra/streaming/SessionSummary.java @@ -28,7 +28,9 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; +import org.apache.cassandra.locator.InetAddressAndPort.Serializer; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; public class SessionSummary { @@ -80,8 +82,8 @@ public int hashCode() { public void serialize(SessionSummary summary, DataOutputPlus out, int version) throws IOException { - CompactEndpointSerializationHelper.instance.serialize(summary.coordinator, out, version); - CompactEndpointSerializationHelper.instance.serialize(summary.peer, out, version); + inetAddressAndPortSerializer.serialize(summary.coordinator, out, version); + inetAddressAndPortSerializer.serialize(summary.peer, out, version); out.writeInt(summary.receivingSummaries.size()); for (StreamSummary streamSummary: summary.receivingSummaries) @@ -98,8 +100,8 @@ public void serialize(SessionSummary summary, DataOutputPlus out, int version) t public SessionSummary deserialize(DataInputPlus in, int version) throws IOException { - 
InetAddressAndPort coordinator = CompactEndpointSerializationHelper.instance.deserialize(in, version); - InetAddressAndPort peer = CompactEndpointSerializationHelper.instance.deserialize(in, version); + InetAddressAndPort coordinator = inetAddressAndPortSerializer.deserialize(in, version); + InetAddressAndPort peer = inetAddressAndPortSerializer.deserialize(in, version); int numRcvd = in.readInt(); List receivingSummaries = new ArrayList<>(numRcvd); @@ -121,8 +123,8 @@ public SessionSummary deserialize(DataInputPlus in, int version) throws IOExcept public long serializedSize(SessionSummary summary, int version) { long size = 0; - size += CompactEndpointSerializationHelper.instance.serializedSize(summary.coordinator, version); - size += CompactEndpointSerializationHelper.instance.serializedSize(summary.peer, version); + size += inetAddressAndPortSerializer.serializedSize(summary.coordinator, version); + size += inetAddressAndPortSerializer.serializedSize(summary.peer, version); size += TypeSizes.sizeof(summary.receivingSummaries.size()); for (StreamSummary streamSummary: summary.receivingSummaries) diff --git a/src/java/org/apache/cassandra/streaming/StreamConnectionFactory.java b/src/java/org/apache/cassandra/streaming/StreamConnectionFactory.java index 4cfe41e01251..95208e400bd6 100644 --- a/src/java/org/apache/cassandra/streaming/StreamConnectionFactory.java +++ b/src/java/org/apache/cassandra/streaming/StreamConnectionFactory.java @@ -21,9 +21,9 @@ import java.io.IOException; import io.netty.channel.Channel; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; +import org.apache.cassandra.net.OutboundConnectionSettings; public interface StreamConnectionFactory { - Channel createConnection(OutboundConnectionIdentifier connectionId, int protocolVersion) throws IOException; + Channel createConnection(OutboundConnectionSettings template, int messagingVersion) throws IOException; } diff --git 
a/src/java/org/apache/cassandra/streaming/StreamCoordinator.java b/src/java/org/apache/cassandra/streaming/StreamCoordinator.java index 7eada287b60b..6d757b6f31bc 100644 --- a/src/java/org/apache/cassandra/streaming/StreamCoordinator.java +++ b/src/java/org/apache/cassandra/streaming/StreamCoordinator.java @@ -19,11 +19,13 @@ import java.util.*; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.cassandra.locator.InetAddressAndPort; + /** * {@link StreamCoordinator} is a helper class that abstracts away maintaining multiple * StreamSession and ProgressInfo instances per peer. @@ -340,5 +342,11 @@ public Collection getAllSessionInfo() { return sessionInfos.values(); } + + @VisibleForTesting + public void shutdown() + { + streamSessions.values().forEach(ss -> ss.sessionFailed()); + } } } diff --git a/src/java/org/apache/cassandra/streaming/StreamReceiveTask.java b/src/java/org/apache/cassandra/streaming/StreamReceiveTask.java index 31c60bebc5e3..87d6ce0dd229 100644 --- a/src/java/org/apache/cassandra/streaming/StreamReceiveTask.java +++ b/src/java/org/apache/cassandra/streaming/StreamReceiveTask.java @@ -19,7 +19,10 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +31,9 @@ import org.apache.cassandra.schema.TableId; import org.apache.cassandra.utils.JVMStabilityInspector; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdown; + /** * Task that manages receiving files for the session for certain ColumnFamily. 
*/ @@ -158,4 +164,11 @@ public synchronized void abort() done = true; receiver.abort(); } + + @VisibleForTesting + public static void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException + { + shutdown(executor); + awaitTermination(timeout, unit, executor); + } } diff --git a/src/java/org/apache/cassandra/streaming/StreamRequest.java b/src/java/org/apache/cassandra/streaming/StreamRequest.java index f37268f200c4..0c8542fcfaa8 100644 --- a/src/java/org/apache/cassandra/streaming/StreamRequest.java +++ b/src/java/org/apache/cassandra/streaming/StreamRequest.java @@ -24,6 +24,7 @@ import java.util.List; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.io.IVersionedSerializer; @@ -32,8 +33,8 @@ import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.RangesAtEndpoint; import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessagingService; + +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; public class StreamRequest { @@ -67,7 +68,7 @@ public void serialize(StreamRequest request, DataOutputPlus out, int version) th out.writeUTF(request.keyspace); out.writeInt(request.columnFamilies.size()); - CompactEndpointSerializationHelper.streamingInstance.serialize(request.full.endpoint(), out, version); + inetAddressAndPortSerializer.serialize(request.full.endpoint(), out, version); serializeReplicas(request.full, out, version); serializeReplicas(request.transientReplicas, out, version); for (String cf : request.columnFamilies) @@ -80,7 +81,7 @@ private void serializeReplicas(RangesAtEndpoint replicas, DataOutputPlus out, in for (Replica replica : replicas) { - MessagingService.validatePartitioner(replica.range()); + 
IPartitioner.validate(replica.range()); Token.serializer.serialize(replica.range().left, out, version); Token.serializer.serialize(replica.range().right, out, version); } @@ -90,7 +91,7 @@ public StreamRequest deserialize(DataInputPlus in, int version) throws IOExcepti { String keyspace = in.readUTF(); int cfCount = in.readInt(); - InetAddressAndPort endpoint = CompactEndpointSerializationHelper.streamingInstance.deserialize(in, version); + InetAddressAndPort endpoint = inetAddressAndPortSerializer.deserialize(in, version); RangesAtEndpoint full = deserializeReplicas(in, version, endpoint, true); RangesAtEndpoint transientReplicas = deserializeReplicas(in, version, endpoint, false); @@ -110,8 +111,8 @@ RangesAtEndpoint deserializeReplicas(DataInputPlus in, int version, InetAddressA //TODO, super need to review the usage of streaming vs not streaming endpoint serialization helper //to make sure I'm not using the wrong one some of the time, like do repair messages use the //streaming version? 
- Token left = Token.serializer.deserialize(in, MessagingService.globalPartitioner(), version); - Token right = Token.serializer.deserialize(in, MessagingService.globalPartitioner(), version); + Token left = Token.serializer.deserialize(in, IPartitioner.global(), version); + Token right = Token.serializer.deserialize(in, IPartitioner.global(), version); replicas.add(new Replica(endpoint, new Range<>(left, right), isFull)); } return replicas.build(); @@ -121,7 +122,7 @@ public long serializedSize(StreamRequest request, int version) { int size = TypeSizes.sizeof(request.keyspace); size += TypeSizes.sizeof(request.columnFamilies.size()); - size += CompactEndpointSerializationHelper.streamingInstance.serializedSize(request.full.endpoint(), version); + size += inetAddressAndPortSerializer.serializedSize(request.full.endpoint(), version); size += replicasSerializedSize(request.transientReplicas, version); size += replicasSerializedSize(request.full, version); for (String cf : request.columnFamilies) diff --git a/src/java/org/apache/cassandra/streaming/StreamSession.java b/src/java/org/apache/cassandra/streaming/StreamSession.java index 08a1b078027a..cad1ba708db4 100644 --- a/src/java/org/apache/cassandra/streaming/StreamSession.java +++ b/src/java/org/apache/cassandra/streaming/StreamSession.java @@ -21,7 +21,6 @@ import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.*; @@ -43,8 +42,7 @@ import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.metrics.StreamingMetrics; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; +import org.apache.cassandra.net.OutboundConnectionSettings; import org.apache.cassandra.schema.TableId; import 
org.apache.cassandra.streaming.async.NettyStreamingMessageSender; import org.apache.cassandra.streaming.messages.*; @@ -52,6 +50,7 @@ import org.apache.cassandra.utils.JVMStabilityInspector; import static com.google.common.collect.Iterables.all; +import static org.apache.cassandra.net.MessagingService.current_version; /** * Handles the streaming a one or more streams to and from a specific remote node. @@ -136,11 +135,7 @@ public class StreamSession implements IEndpointStateChangeSubscriber * Each {@code StreamSession} is identified by this InetAddressAndPort which is broadcast address of the node streaming. */ public final InetAddressAndPort peer; - - /** - * Preferred IP Address/Port of the peer; this is the address that will be connect to. Can be the same as {@linkplain #peer}. - */ - private final InetAddressAndPort preferredPeerInetAddressAndPort; + private final OutboundConnectionSettings template; private final int index; @@ -184,30 +179,25 @@ public enum State public StreamSession(StreamOperation streamOperation, InetAddressAndPort peer, StreamConnectionFactory factory, int index, UUID pendingRepair, PreviewKind previewKind) { - this(streamOperation, peer, factory, index, pendingRepair, previewKind, MessagingService.instance()::getPreferredRemoteAddr); + this(streamOperation, new OutboundConnectionSettings(peer), factory, index, pendingRepair, previewKind); } - - @VisibleForTesting - public StreamSession(StreamOperation streamOperation, InetAddressAndPort peer, StreamConnectionFactory factory, - int index, UUID pendingRepair, PreviewKind previewKind, - Function preferredIpMapper) + /** + * Create new streaming session with the peer. 
+ */ + public StreamSession(StreamOperation streamOperation, OutboundConnectionSettings template, StreamConnectionFactory factory, + int index, UUID pendingRepair, PreviewKind previewKind) { this.streamOperation = streamOperation; - this.peer = peer; + this.peer = template.to; + this.template = template; this.index = index; - InetAddressAndPort preferredPeerEndpoint = preferredIpMapper.apply(peer); - this.preferredPeerInetAddressAndPort = (preferredPeerEndpoint == null) ? peer : preferredPeerEndpoint; - - OutboundConnectionIdentifier id = OutboundConnectionIdentifier.stream(InetAddressAndPort.getByAddressOverrideDefaults(FBUtilities.getJustLocalAddress(), 0), - preferredPeerInetAddressAndPort); - this.messageSender = new NettyStreamingMessageSender(this, id, factory, StreamMessage.CURRENT_VERSION, previewKind.isPreview()); - this.metrics = StreamingMetrics.get(preferredPeerInetAddressAndPort); + this.messageSender = new NettyStreamingMessageSender(this, template, factory, current_version, previewKind.isPreview()); + this.metrics = StreamingMetrics.get(peer); this.pendingRepair = pendingRepair; this.previewKind = previewKind; - logger.debug("Creating stream session peer={} preferredPeerInetAddressAndPort={}", peer, - preferredPeerInetAddressAndPort); + logger.debug("Creating stream session to {}", template); } public UUID planId() @@ -286,7 +276,7 @@ public void start() { logger.info("[Stream #{}] Starting streaming to {}{}", planId(), peer, - peer.equals(preferredPeerInetAddressAndPort) ? "" : " through " + preferredPeerInetAddressAndPort); + template.connectTo == null ? "" : " through " + template.connectTo); messageSender.initialize(); onInitializationComplete(); } @@ -545,7 +535,7 @@ private void logError(Throwable e) logger.error("[Stream #{}] Did not receive response from peer {}{} for {} secs. Is peer down? 
" + "If not, maybe try increasing streaming_keep_alive_period_in_secs.", planId(), peer.getHostAddress(true), - peer.equals(preferredPeerInetAddressAndPort) ? "" : " through " + preferredPeerInetAddressAndPort.getHostAddress(true), + template.connectTo == null ? "" : " through " + template.connectTo.getHostAddress(true), 2 * DatabaseDescriptor.getStreamingKeepAlivePeriod(), e); } @@ -553,7 +543,7 @@ private void logError(Throwable e) { logger.error("[Stream #{}] Streaming error occurred on session with peer {}{}", planId(), peer.getHostAddress(true), - peer.equals(preferredPeerInetAddressAndPort) ? "" : " through " + preferredPeerInetAddressAndPort.getHostAddress(true), + template.connectTo == null ? "" : " through " + template.connectTo.getHostAddress(true), e); } } @@ -707,7 +697,8 @@ public SessionInfo getSessionInfo() List transferSummaries = Lists.newArrayList(); for (StreamTask transfer : transfers.values()) transferSummaries.add(transfer.getSummary()); - return new SessionInfo(peer, index, preferredPeerInetAddressAndPort, receivingSummaries, transferSummaries, state); + // TODO: the connectTo treatment here is peculiar, and needs thinking about - since the connection factory can change it + return new SessionInfo(peer, index, template.connectTo == null ? 
peer : template.connectTo, receivingSummaries, transferSummaries, state); } public synchronized void taskCompleted(StreamReceiveTask completedTask) @@ -789,7 +780,7 @@ private void flushSSTables(Iterable stores) { List> flushes = new ArrayList<>(); for (ColumnFamilyStore cfs : stores) - flushes.add(cfs.forceFlush()); + flushes.add(cfs.forceFlushToSSTable()); FBUtilities.waitOnFutures(flushes); } diff --git a/src/java/org/apache/cassandra/streaming/StreamTransferTask.java b/src/java/org/apache/cassandra/streaming/StreamTransferTask.java index 802188a5c808..ba05acdefea4 100644 --- a/src/java/org/apache/cassandra/streaming/StreamTransferTask.java +++ b/src/java/org/apache/cassandra/streaming/StreamTransferTask.java @@ -34,6 +34,9 @@ import org.apache.cassandra.schema.TableId; import org.apache.cassandra.streaming.messages.OutgoingStreamMessage; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdown; + /** * StreamTransferTask sends streams for a given table */ @@ -178,4 +181,11 @@ public void run() assert prev == null; return future; } + + @VisibleForTesting + public static void shutdownAndWait(long timeout, TimeUnit units) throws InterruptedException, TimeoutException + { + shutdown(timeoutExecutor); + awaitTermination(timeout, units, timeoutExecutor); + } } diff --git a/src/java/org/apache/cassandra/streaming/StreamingMessageSender.java b/src/java/org/apache/cassandra/streaming/StreamingMessageSender.java index accf5548c4a0..96e76267ec49 100644 --- a/src/java/org/apache/cassandra/streaming/StreamingMessageSender.java +++ b/src/java/org/apache/cassandra/streaming/StreamingMessageSender.java @@ -20,7 +20,9 @@ import java.io.IOException; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.streaming.messages.StreamMessage; public 
interface StreamingMessageSender @@ -29,8 +31,6 @@ public interface StreamingMessageSender void sendMessage(StreamMessage message) throws IOException; - OutboundConnectionIdentifier getConnectionId(); - boolean connected(); void close(); diff --git a/src/java/org/apache/cassandra/streaming/async/NettyStreamingMessageSender.java b/src/java/org/apache/cassandra/streaming/async/NettyStreamingMessageSender.java index 8511a87975f5..1314e1d9d8a1 100644 --- a/src/java/org/apache/cassandra/streaming/async/NettyStreamingMessageSender.java +++ b/src/java/org/apache/cassandra/streaming/async/NettyStreamingMessageSender.java @@ -49,9 +49,9 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.io.util.DataOutputBufferFixed; import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; -import org.apache.cassandra.net.async.NettyFactory; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; +import org.apache.cassandra.net.AsyncChannelPromise; +import org.apache.cassandra.net.OutboundConnectionSettings; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.streaming.StreamConnectionFactory; import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.streaming.StreamingMessageSender; @@ -85,13 +85,15 @@ public class NettyStreamingMessageSender implements StreamingMessageSender private static final int DEFAULT_MAX_PARALLEL_TRANSFERS = FBUtilities.getAvailableProcessors(); private static final int MAX_PARALLEL_TRANSFERS = Integer.parseInt(System.getProperty(Config.PROPERTY_PREFIX + "streaming.session.parallelTransfers", Integer.toString(DEFAULT_MAX_PARALLEL_TRANSFERS))); + private static final long DEFAULT_CLOSE_WAIT_IN_MILLIS = TimeUnit.MINUTES.toMillis(5); + // a simple mechansim for allowing a degree of fairnes across multiple sessions private static final Semaphore fileTransferSemaphore = new 
Semaphore(DEFAULT_MAX_PARALLEL_TRANSFERS, true); private final StreamSession session; private final boolean isPreview; - private final int protocolVersion; - private final OutboundConnectionIdentifier connectionId; + private final int streamingVersion; + private final OutboundConnectionSettings template; private final StreamConnectionFactory factory; private volatile boolean closed; @@ -120,11 +122,11 @@ public class NettyStreamingMessageSender implements StreamingMessageSender @VisibleForTesting static final AttributeKey TRANSFERRING_FILE_ATTR = AttributeKey.valueOf("transferringFile"); - public NettyStreamingMessageSender(StreamSession session, OutboundConnectionIdentifier connectionId, StreamConnectionFactory factory, int protocolVersion, boolean isPreview) + public NettyStreamingMessageSender(StreamSession session, OutboundConnectionSettings template, StreamConnectionFactory factory, int streamingVersion, boolean isPreview) { this.session = session; - this.protocolVersion = protocolVersion; - this.connectionId = connectionId; + this.streamingVersion = streamingVersion; + this.template = template; this.factory = factory; this.isPreview = isPreview; @@ -181,9 +183,9 @@ private void scheduleKeepAliveTask(Channel channel) private Channel createChannel() throws IOException { - Channel channel = factory.createConnection(connectionId, protocolVersion); + Channel channel = factory.createConnection(template, streamingVersion); ChannelPipeline pipeline = channel.pipeline(); - pipeline.addLast(NettyFactory.instance.streamingGroup, NettyFactory.INBOUND_STREAM_HANDLER_NAME, new StreamingInboundHandler(connectionId.remote(), protocolVersion, session)); + pipeline.addLast("stream", new StreamingInboundHandler(template.to, streamingVersion, session)); channel.attr(TRANSFERRING_FILE_ATTR).set(Boolean.FALSE); logger.debug("Creating channel id {} local {} remote {}", channel.id(), channel.localAddress(), channel.remoteAddress()); return channel; @@ -238,7 +240,7 @@ private void 
sendControlMessage(Channel channel, StreamMessage message, GenericF logger.debug("{} Sending {}", createLogTag(session, channel), message); // we anticipate that the control messages are rather small, so allocating a ByteBuf shouldn't blow out of memory. - long messageSize = StreamMessage.serializedSize(message, protocolVersion); + long messageSize = StreamMessage.serializedSize(message, streamingVersion); if (messageSize > 1 << 30) { throw new IllegalStateException(String.format("%s something is seriously wrong with the calculated stream control message's size: %d bytes, type is %s", @@ -250,12 +252,11 @@ private void sendControlMessage(Channel channel, StreamMessage message, GenericF ByteBuffer nioBuf = buf.nioBuffer(0, (int) messageSize); @SuppressWarnings("resource") DataOutputBufferFixed out = new DataOutputBufferFixed(nioBuf); - StreamMessage.serialize(message, out, protocolVersion, session); + StreamMessage.serialize(message, out, streamingVersion, session); assert nioBuf.position() == nioBuf.limit(); buf.writerIndex(nioBuf.position()); - ChannelFuture channelFuture = channel.writeAndFlush(buf); - channelFuture.addListener(future -> listener.operationComplete(future)); + AsyncChannelPromise.writeAndFlush(channel, buf, listener); } /** @@ -275,7 +276,7 @@ java.util.concurrent.Future onControlMessageComplete(Future future, StreamMes Channel channel = channelFuture.channel(); logger.error("{} failed to send a stream message/data to peer {}: msg = {}", - createLogTag(session, channel), connectionId, msg, future.cause()); + createLogTag(session, channel), template.to, msg, future.cause()); // StreamSession will invoke close(), but we have to mark this sender as closed so the session doesn't try // to send any failure messages @@ -322,10 +323,9 @@ public void run() throw new IllegalStateException("channel's transferring state is currently set to true. 
refusing to start new stream"); // close the DataOutputStreamPlus as we're done with it - but don't close the channel - try (DataOutputStreamPlus outPlus = ByteBufDataOutputStreamPlus.create(session, channel, 1 << 20)) + try (DataOutputStreamPlus outPlus = new AsyncStreamingOutputPlus(channel)) { - StreamMessage.serialize(msg, outPlus, protocolVersion, session); - channel.flush(); + StreamMessage.serialize(msg, outPlus, streamingVersion, session); } finally { @@ -393,6 +393,18 @@ private Channel getOrCreateChannel() } } + private void onError(Throwable t) + { + try + { + session.onError(t).get(DEFAULT_CLOSE_WAIT_IN_MILLIS, TimeUnit.MILLISECONDS); + } + catch (Exception e) + { + // nop - let the Throwable param be the main failure point here, and let session handle it + } + } + /** * For testing purposes */ @@ -477,7 +489,7 @@ private void keepAliveListener(Future future) /** * For testing purposes only. */ - void setClosed() + public void setClosed() { closed = true; } @@ -495,7 +507,7 @@ int semaphoreAvailablePermits() @Override public boolean connected() { - return !closed; + return !closed && (controlMessageChannel == null || controlMessageChannel.isOpen()); } @Override @@ -503,7 +515,7 @@ public void close() { closed = true; if (logger.isDebugEnabled()) - logger.debug("{} Closing stream connection channels on {}", createLogTag(session, null), connectionId); + logger.debug("{} Closing stream connection channels on {}", createLogTag(session, null), template.to); for (ScheduledFuture future : channelKeepAlives) future.cancel(false); channelKeepAlives.clear(); @@ -518,10 +530,4 @@ public void close() if (controlMessageChannel != null) controlMessageChannel.close(); } - - @Override - public OutboundConnectionIdentifier getConnectionId() - { - return connectionId; - } } diff --git a/src/java/org/apache/cassandra/streaming/async/StreamCompressionSerializer.java b/src/java/org/apache/cassandra/streaming/async/StreamCompressionSerializer.java index 
ca15b78dc1bc..1d834097d9a9 100644 --- a/src/java/org/apache/cassandra/streaming/async/StreamCompressionSerializer.java +++ b/src/java/org/apache/cassandra/streaming/async/StreamCompressionSerializer.java @@ -27,6 +27,9 @@ import net.jpountz.lz4.LZ4Compressor; import net.jpountz.lz4.LZ4FastDecompressor; import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; + +import static org.apache.cassandra.net.MessagingService.current_version; /** * A serialiazer for stream compressed files (see package-level documentation). Much like a typical compressed @@ -51,29 +54,20 @@ public StreamCompressionSerializer(ByteBufAllocator allocator) */ private static final int HEADER_LENGTH = 8; - /** - * @return A buffer with decompressed data. - */ - public ByteBuf serialize(LZ4Compressor compressor, ByteBuffer in, int version) + public static AsyncStreamingOutputPlus.Write serialize(LZ4Compressor compressor, ByteBuffer in, int version) { - final int uncompressedLength = in.remaining(); - int maxLength = compressor.maxCompressedLength(uncompressedLength); - ByteBuf out = allocator.directBuffer(maxLength); - try - { - ByteBuffer compressedNioBuffer = out.nioBuffer(HEADER_LENGTH, maxLength - HEADER_LENGTH); - compressor.compress(in, compressedNioBuffer); - final int compressedLength = compressedNioBuffer.position(); - out.setInt(0, compressedLength); - out.setInt(4, uncompressedLength); - out.writerIndex(HEADER_LENGTH + compressedLength); - } - catch (Exception e) - { - if (out != null) - out.release(); - } - return out; + assert version == current_version; + return bufferSupplier -> { + int uncompressedLength = in.remaining(); + int maxLength = compressor.maxCompressedLength(uncompressedLength); + ByteBuffer out = bufferSupplier.get(maxLength); + out.position(HEADER_LENGTH); + compressor.compress(in, out); + int compressedLength = out.position() - HEADER_LENGTH; + out.putInt(0, compressedLength); + out.putInt(4, uncompressedLength); + 
out.flip(); + }; } /** diff --git a/src/java/org/apache/cassandra/streaming/async/StreamingInboundHandler.java b/src/java/org/apache/cassandra/streaming/async/StreamingInboundHandler.java index 7c10ef96efdf..a319fea47c7e 100644 --- a/src/java/org/apache/cassandra/streaming/async/StreamingInboundHandler.java +++ b/src/java/org/apache/cassandra/streaming/async/StreamingInboundHandler.java @@ -20,12 +20,18 @@ import java.io.EOFException; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.Uninterruptibles; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,16 +43,17 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.FastThreadLocalThread; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus.InputTimeoutException; import org.apache.cassandra.streaming.StreamManager; import org.apache.cassandra.streaming.StreamReceiveException; import org.apache.cassandra.streaming.StreamResultFuture; import org.apache.cassandra.streaming.StreamSession; -import org.apache.cassandra.streaming.messages.StreamMessageHeader; import org.apache.cassandra.streaming.messages.IncomingStreamMessage; import org.apache.cassandra.streaming.messages.KeepAliveMessage; import org.apache.cassandra.streaming.messages.StreamInitMessage; import org.apache.cassandra.streaming.messages.StreamMessage; +import org.apache.cassandra.streaming.messages.StreamMessageHeader; import 
org.apache.cassandra.utils.JVMStabilityInspector; import static org.apache.cassandra.streaming.async.NettyStreamingMessageSender.createLogTag; @@ -59,11 +66,9 @@ public class StreamingInboundHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(StreamingInboundHandler.class); - static final Function DEFAULT_SESSION_PROVIDER = sid -> StreamManager.instance.findSession(sid.from, sid.planId, sid.sessionIndex); - - private static final int AUTO_READ_LOW_WATER_MARK = 1 << 15; - private static final int AUTO_READ_HIGH_WATER_MARK = 1 << 20; - + private static final Function DEFAULT_SESSION_PROVIDER = sid -> StreamManager.instance.findSession(sid.from, sid.planId, sid.sessionIndex); + private static volatile boolean trackInboundHandlers = false; + private static Collection inboundHandlers; private final InetAddressAndPort remoteAddress; private final int protocolVersion; @@ -74,10 +79,10 @@ public class StreamingInboundHandler extends ChannelInboundHandlerAdapter * structure, and then consumed. *

* For thread safety, this structure's resources are released on the consuming thread - * (via {@link RebufferingByteBufDataInputPlus#close()}, - * but the producing side calls {@link RebufferingByteBufDataInputPlus#markClose()} to notify the input that is should close. + * (via {@link AsyncStreamingInputPlus#close()}, + * but the producing side calls {@link AsyncStreamingInputPlus#requestClosure()} to notify the input that is should close. */ - private RebufferingByteBufDataInputPlus buffers; + private AsyncStreamingInputPlus buffers; private volatile boolean closed; @@ -86,13 +91,15 @@ public StreamingInboundHandler(InetAddressAndPort remoteAddress, int protocolVer this.remoteAddress = remoteAddress; this.protocolVersion = protocolVersion; this.session = session; + if (trackInboundHandlers) + inboundHandlers.add(this); } @Override @SuppressWarnings("resource") public void handlerAdded(ChannelHandlerContext ctx) { - buffers = new RebufferingByteBufDataInputPlus(AUTO_READ_LOW_WATER_MARK, AUTO_READ_HIGH_WATER_MARK, ctx.channel().config()); + buffers = new AsyncStreamingInputPlus(ctx.channel()); Thread blockingIOThread = new FastThreadLocalThread(new StreamDeserializingTask(DEFAULT_SESSION_PROVIDER, session, ctx.channel()), String.format("Stream-Deserializer-%s-%s", remoteAddress.toString(), ctx.channel().id())); blockingIOThread.setDaemon(true); @@ -102,9 +109,7 @@ public void handlerAdded(ChannelHandlerContext ctx) @Override public void channelRead(ChannelHandlerContext ctx, Object message) { - if (!closed && message instanceof ByteBuf) - buffers.append((ByteBuf) message); - else + if (closed || !(message instanceof ByteBuf) || !buffers.append((ByteBuf) message)) ReferenceCountUtil.release(message); } @@ -118,7 +123,9 @@ public void channelInactive(ChannelHandlerContext ctx) void close() { closed = true; - buffers.markClose(); + buffers.requestClosure(); + if (trackInboundHandlers) + inboundHandlers.remove(this); } @Override @@ -134,7 +141,7 @@ public void 
exceptionCaught(ChannelHandlerContext ctx, Throwable cause) /** * For testing only!! */ - void setPendingBuffers(RebufferingByteBufDataInputPlus bufChannel) + void setPendingBuffers(AsyncStreamingInputPlus bufChannel) { this.buffers = bufChannel; } @@ -164,9 +171,11 @@ public void run() { while (true) { + buffers.maybeIssueRead(); + // do a check of available bytes and possibly sleep some amount of time (then continue). // this way we can break out of run() sanely or we end up blocking indefintely in StreamMessage.deserialize() - while (buffers.available() == 0) + while (buffers.isEmpty()) { if (closed) return; @@ -190,10 +199,11 @@ public void run() if (logger.isDebugEnabled()) logger.debug("{} Received {}", createLogTag(session, channel), message); + session.messageReceived(message); } } - catch (EOFException eof) + catch (InputTimeoutException | EOFException e) { // ignore } @@ -219,11 +229,17 @@ else if (t instanceof StreamReceiveException) closed = true; if (buffers != null) + { + // request closure again as the original request could have raced with receiving a + // message and been consumed in the message receive loop above. Otherweise + // buffers could hang indefinitely on the queue.poll. + buffers.requestClosure(); buffers.close(); + } } } - StreamSession deriveSession(StreamMessage message) throws IOException + StreamSession deriveSession(StreamMessage message) { StreamSession streamSession = null; // StreamInitMessage starts a new channel, and IncomingStreamMessage potentially, as well. @@ -267,4 +283,24 @@ static class SessionIdentifier this.sessionIndex = sessionIndex; } } + + /** Shutdown for in-JVM tests. For any other usage, tracking of active inbound streaming handlers + * should be revisted first and in-JVM shutdown refactored with it. + * This does not prevent new inbound handlers being added after shutdown, nor is not thread-safe + * around new inbound handlers being opened during shutdown. 
+ */ + @VisibleForTesting + public static void shutdown() + { + assert trackInboundHandlers == true : "in-JVM tests required tracking of inbound streaming handlers"; + + inboundHandlers.forEach(StreamingInboundHandler::close); + inboundHandlers.clear(); + } + + public static void trackInboundHandlers() + { + inboundHandlers = Collections.newSetFromMap(new ConcurrentHashMap<>()); + trackInboundHandlers = true; + } } diff --git a/src/java/org/apache/cassandra/streaming/compress/ByteBufCompressionDataOutputStreamPlus.java b/src/java/org/apache/cassandra/streaming/compress/ByteBufCompressionDataOutputStreamPlus.java deleted file mode 100644 index 3f1b22b53ce0..000000000000 --- a/src/java/org/apache/cassandra/streaming/compress/ByteBufCompressionDataOutputStreamPlus.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.streaming.compress; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import io.netty.buffer.ByteBuf; -import net.jpountz.lz4.LZ4Compressor; -import net.jpountz.lz4.LZ4Factory; -import org.apache.cassandra.io.util.DataOutputStreamPlus; -import org.apache.cassandra.io.util.WrappedDataOutputStreamPlus; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; -import org.apache.cassandra.streaming.StreamManager.StreamRateLimiter; -import org.apache.cassandra.streaming.async.StreamCompressionSerializer; -import org.apache.cassandra.streaming.messages.StreamMessage; - -/** - * The intent of this class is to only be used in a very narrow use-case: on the stream compression path of streaming. - * This class should really only get calls to {@link #write(ByteBuffer)}, where the incoming buffer is compressed and sent - * downstream. - */ -public class ByteBufCompressionDataOutputStreamPlus extends WrappedDataOutputStreamPlus -{ - private final StreamRateLimiter limiter; - private final LZ4Compressor compressor; - private final StreamCompressionSerializer serializer; - - public ByteBufCompressionDataOutputStreamPlus(DataOutputStreamPlus out, StreamRateLimiter limiter) - { - super(out); - assert out instanceof ByteBufDataOutputStreamPlus; - compressor = LZ4Factory.fastestInstance().fastCompressor(); - serializer = new StreamCompressionSerializer(((ByteBufDataOutputStreamPlus)out).getAllocator()); - this.limiter = limiter; - } - - /** - * {@inheritDoc} - * - * Compress the incoming buffer and send the result downstream. The buffer parameter will not be used nor passed - * to downstream components, and thus callers can safely free the buffer upon return. 
- */ - @Override - public void write(ByteBuffer buffer) throws IOException - { - ByteBuf compressed = serializer.serialize(compressor, buffer, StreamMessage.CURRENT_VERSION); - - // this is a blocking call - you have been warned - limiter.acquire(compressed.readableBytes()); - - ((ByteBufDataOutputStreamPlus)out).writeToChannel(compressed); - } - - @Override - public void close() - { - // explicitly overriding close() to avoid closing the wrapped stream; it will be closed via other means - } -} diff --git a/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java b/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java index daf6d28e9424..50d746aeb66f 100644 --- a/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java +++ b/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java @@ -28,7 +28,7 @@ import net.jpountz.lz4.LZ4FastDecompressor; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.RebufferingInputStream; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; import org.apache.cassandra.streaming.async.StreamCompressionSerializer; public class StreamCompressionInputStream extends RebufferingInputStream implements AutoCloseable @@ -56,8 +56,8 @@ public StreamCompressionInputStream(DataInputPlus dataInputPlus, int protocolVer this.protocolVersion = protocolVersion; this.decompressor = LZ4Factory.fastestInstance().fastDecompressor(); - ByteBufAllocator allocator = dataInputPlus instanceof RebufferingByteBufDataInputPlus - ? ((RebufferingByteBufDataInputPlus)dataInputPlus).getAllocator() + ByteBufAllocator allocator = dataInputPlus instanceof AsyncStreamingInputPlus + ? 
((AsyncStreamingInputPlus)dataInputPlus).getAllocator() : PooledByteBufAllocator.DEFAULT; deserializer = new StreamCompressionSerializer(allocator); } diff --git a/src/java/org/apache/cassandra/streaming/messages/StreamInitMessage.java b/src/java/org/apache/cassandra/streaming/messages/StreamInitMessage.java index a591a43f5ce1..e14879039017 100644 --- a/src/java/org/apache/cassandra/streaming/messages/StreamInitMessage.java +++ b/src/java/org/apache/cassandra/streaming/messages/StreamInitMessage.java @@ -24,13 +24,14 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputStreamPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.streaming.StreamOperation; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.utils.UUIDSerializer; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + /** * StreamInitMessage is first sent from the node where {@link org.apache.cassandra.streaming.StreamSession} is started, * to initiate corresponding {@link org.apache.cassandra.streaming.StreamSession} on the other side. 
@@ -72,7 +73,7 @@ private static class StreamInitMessageSerializer implements Serializer /** StreamMessage types */ public enum Type { - PREPARE_SYN(1, 5, PrepareSynMessage.serializer), - STREAM(2, 0, IncomingStreamMessage.serializer, OutgoingStreamMessage.serializer), - RECEIVED(3, 4, ReceivedMessage.serializer), - COMPLETE(5, 1, CompleteMessage.serializer), - SESSION_FAILED(6, 5, SessionFailedMessage.serializer), - KEEP_ALIVE(7, 5, KeepAliveMessage.serializer), - PREPARE_SYNACK(8, 5, PrepareSynAckMessage.serializer), - PREPARE_ACK(9, 5, PrepareAckMessage.serializer), - STREAM_INIT(10, 5, StreamInitMessage.serializer); - - public static Type get(byte type) + PREPARE_SYN (1, 5, PrepareSynMessage.serializer ), + STREAM (2, 0, IncomingStreamMessage.serializer, OutgoingStreamMessage.serializer), + RECEIVED (3, 4, ReceivedMessage.serializer ), + COMPLETE (5, 1, CompleteMessage.serializer ), + SESSION_FAILED (6, 5, SessionFailedMessage.serializer), + KEEP_ALIVE (7, 5, KeepAliveMessage.serializer ), + PREPARE_SYNACK (8, 5, PrepareSynAckMessage.serializer), + PREPARE_ACK (9, 5, PrepareAckMessage.serializer ), + STREAM_INIT (10, 5, StreamInitMessage.serializer ); + + private static final Type[] idToTypeMap; + + static { - for (Type t : Type.values()) + Type[] values = values(); + + int max = Integer.MIN_VALUE; + for (Type t : values) + max = max(t.id, max); + + Type[] idMap = new Type[max + 1]; + for (Type t : values) { - if (t.type == type) - return t; + if (idMap[t.id] != null) + throw new RuntimeException("Two StreamMessage Types map to the same id: " + t.id); + idMap[t.id] = t; } - throw new IllegalArgumentException("Unknown type " + type); + + idToTypeMap = idMap; } - private final byte type; + public static Type lookupById(int id) + { + if (id < 0 || id >= idToTypeMap.length) + throw new IllegalArgumentException("Invalid type id: " + id); + + return idToTypeMap[id]; + } + + public final int id; public final int priority; + public final Serializer inSerializer; public 
final Serializer outSerializer; - @SuppressWarnings("unchecked") - private Type(int type, int priority, Serializer serializer) + Type(int id, int priority, Serializer serializer) { - this(type, priority, serializer, serializer); + this(id, priority, serializer, serializer); } @SuppressWarnings("unchecked") - private Type(int type, int priority, Serializer inSerializer, Serializer outSerializer) + Type(int id, int priority, Serializer inSerializer, Serializer outSerializer) { - this.type = (byte) type; + if (id < 0 || id > Byte.MAX_VALUE) + throw new IllegalArgumentException("StreamMessage Type id must be non-negative and less than " + Byte.MAX_VALUE); + + this.id = id; this.priority = priority; this.inSerializer = inSerializer; this.outSerializer = outSerializer; diff --git a/src/java/org/apache/cassandra/streaming/messages/StreamMessageHeader.java b/src/java/org/apache/cassandra/streaming/messages/StreamMessageHeader.java index 84cf3a392aa0..e76777a00e51 100644 --- a/src/java/org/apache/cassandra/streaming/messages/StreamMessageHeader.java +++ b/src/java/org/apache/cassandra/streaming/messages/StreamMessageHeader.java @@ -24,12 +24,13 @@ import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.utils.UUIDSerializer; +import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer; + /** * StreamingFileHeader is appended before sending actual data to describe what it's sending. 
*/ @@ -102,7 +103,7 @@ static class FileMessageHeaderSerializer public void serialize(StreamMessageHeader header, DataOutputPlus out, int version) throws IOException { header.tableId.serialize(out); - CompactEndpointSerializationHelper.streamingInstance.serialize(header.sender, out, version); + inetAddressAndPortSerializer.serialize(header.sender, out, version); UUIDSerializer.serializer.serialize(header.planId, out, version); out.writeInt(header.sessionIndex); out.writeInt(header.sequenceNumber); @@ -117,7 +118,7 @@ public void serialize(StreamMessageHeader header, DataOutputPlus out, int versio public StreamMessageHeader deserialize(DataInputPlus in, int version) throws IOException { TableId tableId = TableId.deserialize(in); - InetAddressAndPort sender = CompactEndpointSerializationHelper.streamingInstance.deserialize(in, version); + InetAddressAndPort sender = inetAddressAndPortSerializer.deserialize(in, version); UUID planId = UUIDSerializer.serializer.deserialize(in, MessagingService.current_version); int sessionIndex = in.readInt(); int sequenceNumber = in.readInt(); @@ -130,7 +131,7 @@ public StreamMessageHeader deserialize(DataInputPlus in, int version) throws IOE public long serializedSize(StreamMessageHeader header, int version) { long size = header.tableId.serializedSize(); - size += CompactEndpointSerializationHelper.streamingInstance.serializedSize(header.sender, version); + size += inetAddressAndPortSerializer.serializedSize(header.sender, version); size += UUIDSerializer.serializer.serializedSize(header.planId, version); size += TypeSizes.sizeof(header.sessionIndex); size += TypeSizes.sizeof(header.sequenceNumber); diff --git a/src/java/org/apache/cassandra/tools/BulkLoadConnectionFactory.java b/src/java/org/apache/cassandra/tools/BulkLoadConnectionFactory.java index cce686fb0103..177f811898e7 100644 --- a/src/java/org/apache/cassandra/tools/BulkLoadConnectionFactory.java +++ b/src/java/org/apache/cassandra/tools/BulkLoadConnectionFactory.java @@ 
-19,18 +19,16 @@ package org.apache.cassandra.tools; import java.io.IOException; -import java.net.InetSocketAddress; import io.netty.channel.Channel; -import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.config.EncryptionOptions; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; +import org.apache.cassandra.net.OutboundConnectionSettings; import org.apache.cassandra.streaming.DefaultConnectionFactory; -import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.streaming.StreamConnectionFactory; public class BulkLoadConnectionFactory extends DefaultConnectionFactory implements StreamConnectionFactory { + // TODO: what is this unused variable for? private final boolean outboundBindAny; private final int secureStoragePort; private final EncryptionOptions.ServerEncryptionOptions encryptionOptions; @@ -38,21 +36,19 @@ public class BulkLoadConnectionFactory extends DefaultConnectionFactory implemen public BulkLoadConnectionFactory(int secureStoragePort, EncryptionOptions.ServerEncryptionOptions encryptionOptions, boolean outboundBindAny) { this.secureStoragePort = secureStoragePort; - this.encryptionOptions = encryptionOptions != null && encryptionOptions.internode_encryption == EncryptionOptions.ServerEncryptionOptions.InternodeEncryption.none - ? null - : encryptionOptions; + this.encryptionOptions = encryptionOptions; this.outboundBindAny = outboundBindAny; } - public Channel createConnection(OutboundConnectionIdentifier connectionId, int protocolVersion) throws IOException + public Channel createConnection(OutboundConnectionSettings template, int messagingVersion) throws IOException { // Connect to secure port for all peers if ServerEncryptionOptions is configured other than 'none' // When 'all', 'dc' and 'rack', server nodes always have SSL port open, and since thin client like sstableloader // does not know which node is in which dc/rack, connecting to SSL port is always the option. 
- int port = encryptionOptions != null && encryptionOptions.internode_encryption != EncryptionOptions.ServerEncryptionOptions.InternodeEncryption.none ? - secureStoragePort : connectionId.remote().port; - connectionId = connectionId.withNewConnectionAddress(InetAddressAndPort.getByAddressOverrideDefaults(connectionId.remote().address, port)); - return createConnection(connectionId, protocolVersion, encryptionOptions); + if (encryptionOptions != null && encryptionOptions.internode_encryption != EncryptionOptions.ServerEncryptionOptions.InternodeEncryption.none) + template = template.withConnectTo(template.to.withPort(secureStoragePort)); + + return super.createConnection(template, messagingVersion); } } diff --git a/src/java/org/apache/cassandra/tools/BulkLoader.java b/src/java/org/apache/cassandra/tools/BulkLoader.java index d85c6054ba6d..2ca2a3d5b273 100644 --- a/src/java/org/apache/cassandra/tools/BulkLoader.java +++ b/src/java/org/apache/cassandra/tools/BulkLoader.java @@ -265,7 +265,7 @@ private static SSLOptions buildSSLOptions(EncryptionOptions clientEncryptionOpti return JdkSSLOptions.builder() .withSSLContext(sslContext) - .withCipherSuites(clientEncryptionOptions.cipher_suites) + .withCipherSuites(clientEncryptionOptions.cipher_suites.toArray(new String[0])) .build(); } diff --git a/src/java/org/apache/cassandra/tools/LoaderOptions.java b/src/java/org/apache/cassandra/tools/LoaderOptions.java index d6cb670655e3..7ad3299e298a 100644 --- a/src/java/org/apache/cassandra/tools/LoaderOptions.java +++ b/src/java/org/apache/cassandra/tools/LoaderOptions.java @@ -468,50 +468,49 @@ public Builder parseArgs(String cmdArgs[]) if (cmd.hasOption(SSL_TRUSTSTORE) || cmd.hasOption(SSL_TRUSTSTORE_PW) || cmd.hasOption(SSL_KEYSTORE) || cmd.hasOption(SSL_KEYSTORE_PW)) { - clientEncOptions.enabled = true; + clientEncOptions = clientEncOptions.withEnabled(true); } if (cmd.hasOption(SSL_TRUSTSTORE)) { - clientEncOptions.truststore = cmd.getOptionValue(SSL_TRUSTSTORE); + 
clientEncOptions = clientEncOptions.withTrustStore(cmd.getOptionValue(SSL_TRUSTSTORE)); } if (cmd.hasOption(SSL_TRUSTSTORE_PW)) { - clientEncOptions.truststore_password = cmd.getOptionValue(SSL_TRUSTSTORE_PW); + clientEncOptions = clientEncOptions.withTrustStorePassword(cmd.getOptionValue(SSL_TRUSTSTORE_PW)); } if (cmd.hasOption(SSL_KEYSTORE)) { - clientEncOptions.keystore = cmd.getOptionValue(SSL_KEYSTORE); // if a keystore was provided, lets assume we'll need to use - // it - clientEncOptions.require_client_auth = true; + clientEncOptions = clientEncOptions.withKeyStore(cmd.getOptionValue(SSL_KEYSTORE)) + .withRequireClientAuth(true); } if (cmd.hasOption(SSL_KEYSTORE_PW)) { - clientEncOptions.keystore_password = cmd.getOptionValue(SSL_KEYSTORE_PW); + clientEncOptions = clientEncOptions.withKeyStorePassword(cmd.getOptionValue(SSL_KEYSTORE_PW)); } if (cmd.hasOption(SSL_PROTOCOL)) { - clientEncOptions.protocol = cmd.getOptionValue(SSL_PROTOCOL); + clientEncOptions = clientEncOptions.withProtocol(cmd.getOptionValue(SSL_PROTOCOL)); } if (cmd.hasOption(SSL_ALGORITHM)) { - clientEncOptions.algorithm = cmd.getOptionValue(SSL_ALGORITHM); + clientEncOptions = clientEncOptions.withAlgorithm(cmd.getOptionValue(SSL_ALGORITHM)); } if (cmd.hasOption(SSL_STORE_TYPE)) { - clientEncOptions.store_type = cmd.getOptionValue(SSL_STORE_TYPE); + clientEncOptions = clientEncOptions.withStoreType(cmd.getOptionValue(SSL_STORE_TYPE)); } if (cmd.hasOption(SSL_CIPHER_SUITES)) { - clientEncOptions.cipher_suites = cmd.getOptionValue(SSL_CIPHER_SUITES).split(","); + clientEncOptions = clientEncOptions.withCipherSuites(cmd.getOptionValue(SSL_CIPHER_SUITES).split(",")); } if (cmd.hasOption(TARGET_KEYSPACE)) diff --git a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java index 8ff964f0d57a..8f7e8a5dbd8d 100755 --- a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java +++ 
b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java @@ -398,7 +398,7 @@ private void printSStableMetadata(String fname, boolean scan) throws IOException String::valueOf); rowSize.printHistogram(out, color, unicode); field("Column Count", ""); - TermHistogram cellCount = new TermHistogram(stats.estimatedColumnCount, + TermHistogram cellCount = new TermHistogram(stats.estimatedCellPerPartitionCount, "Columns", String::valueOf, String::valueOf); diff --git a/src/java/org/apache/cassandra/tools/nodetool/Snapshot.java b/src/java/org/apache/cassandra/tools/nodetool/Snapshot.java index 8d01d3a2b96b..495ee9dae539 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/Snapshot.java +++ b/src/java/org/apache/cassandra/tools/nodetool/Snapshot.java @@ -71,7 +71,7 @@ public void execute(NodeProbe probe) else { throw new IOException( - "When specifying the Keyspace columfamily list for a snapshot, you should not specify columnfamily"); + "When specifying the Keyspace table list (using -kt,--kt-list,-kc,--kc.list), you must not also specify keyspaces to snapshot"); } if (!snapshotName.isEmpty()) sb.append(" with snapshot name [").append(snapshotName).append("]"); diff --git a/src/java/org/apache/cassandra/tools/nodetool/Status.java b/src/java/org/apache/cassandra/tools/nodetool/Status.java index 21868e722220..8c37022f5173 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/Status.java +++ b/src/java/org/apache/cassandra/tools/nodetool/Status.java @@ -25,18 +25,16 @@ import java.net.UnknownHostException; import java.text.DecimalFormat; import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.SortedMap; -import java.util.function.ToIntFunction; import org.apache.cassandra.locator.EndpointSnitchInfoMBean; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.tools.NodeProbe; import org.apache.cassandra.tools.NodeTool; import 
org.apache.cassandra.tools.NodeTool.NodeToolCmd; +import org.apache.cassandra.tools.nodetool.formatter.TableBuilder; import com.google.common.collect.ArrayListMultimap; @@ -51,7 +49,6 @@ public class Status extends NodeToolCmd private boolean resolveIp = false; private boolean isTokenPerNode = true; - private String format = null; private Collection joiningNodes, leavingNodes, movingNodes, liveNodes, unreachableNodes; private Map loadMap, hostIDMap; private EndpointSnitchInfoMBean epSnitchInfo; @@ -70,6 +67,7 @@ public void execute(NodeProbe probe) epSnitchInfo = probe.getEndpointSnitchInfoProxy(); StringBuilder errors = new StringBuilder(); + TableBuilder tableBuilder = new TableBuilder(" "); if (printPort) { @@ -97,8 +95,6 @@ public void execute(NodeProbe probe) if (dcs.size() < tokensToEndpoints.size()) isTokenPerNode = false; - int maxAddressLength = findMaxAddressLength(dcs, s -> s.ipOrDns().length()); - // Datacenters for (Map.Entry dc : dcs.entrySet()) { @@ -111,7 +107,7 @@ public void execute(NodeProbe probe) System.out.println("Status=Up/Down"); System.out.println("|/ State=Normal/Leaving/Joining/Moving"); - printNodesHeader(hasEffectiveOwns, isTokenPerNode, maxAddressLength); + addNodesHeader(hasEffectiveOwns, tableBuilder); ArrayListMultimap hostToTokens = ArrayListMultimap.create(); for (HostStatWithPort stat : dc.getValue()) @@ -121,10 +117,11 @@ public void execute(NodeProbe probe) { Float owns = ownerships.get(endpoint.toString()); List tokens = hostToTokens.get(endpoint); - printNodeWithPort(endpoint.toString(), owns, tokens, hasEffectiveOwns, isTokenPerNode, maxAddressLength); + addNodeWithPort(endpoint.toString(), owns, tokens, hasEffectiveOwns, tableBuilder); } } + tableBuilder.printTo(System.out); System.out.printf("%n" + errors); } else @@ -153,8 +150,6 @@ public void execute(NodeProbe probe) if (dcs.values().size() < tokensToEndpoints.keySet().size()) isTokenPerNode = false; - int maxAddressLength = findMaxAddressLength(dcs, s -> 
s.ipOrDns().length()); - // Datacenters for (Map.Entry dc : dcs.entrySet()) { @@ -167,7 +162,7 @@ public void execute(NodeProbe probe) System.out.println("Status=Up/Down"); System.out.println("|/ State=Normal/Leaving/Joining/Moving"); - printNodesHeader(hasEffectiveOwns, isTokenPerNode, maxAddressLength); + addNodesHeader(hasEffectiveOwns, tableBuilder); ArrayListMultimap hostToTokens = ArrayListMultimap.create(); for (HostStat stat : dc.getValue()) @@ -177,43 +172,29 @@ public void execute(NodeProbe probe) { Float owns = ownerships.get(endpoint); List tokens = hostToTokens.get(endpoint); - printNode(endpoint.getHostAddress(), owns, tokens, hasEffectiveOwns, isTokenPerNode, maxAddressLength); + addNode(endpoint.getHostAddress(), owns, tokens, hasEffectiveOwns, tableBuilder); } } + tableBuilder.printTo(System.out); System.out.printf("%n" + errors); } } - private , U> int findMaxAddressLength(Map dcs, ToIntFunction computeLength) + private void addNodesHeader(boolean hasEffectiveOwns, TableBuilder tableBuilder) { - int maxAddressLength = 0; - - Set seenHosts = new HashSet<>(); - for (T stats : dcs.values()) - for (U stat : stats) - if (seenHosts.add(stat)) - maxAddressLength = Math.max(maxAddressLength, computeLength.applyAsInt(stat)); - - return maxAddressLength; - } - - private void printNodesHeader(boolean hasEffectiveOwns, boolean isTokenPerNode, int maxAddressLength) - { - String fmt = getFormat(hasEffectiveOwns, isTokenPerNode, maxAddressLength); String owns = hasEffectiveOwns ? 
"Owns (effective)" : "Owns"; if (isTokenPerNode) - System.out.printf(fmt, "-", "-", "Address", "Load", owns, "Host ID", "Token", "Rack"); + tableBuilder.add("--", "Address", "Load", owns, "Host ID", "Token", "Rack"); else - System.out.printf(fmt, "-", "-", "Address", "Load", "Tokens", owns, "Host ID", "Rack"); + tableBuilder.add("--", "Address", "Load", "Tokens", owns, "Host ID", "Rack"); } - private void printNode(String endpoint, Float owns, String epDns, String token, int size, boolean hasEffectiveOwns, - boolean isTokenPerNode, int maxAddressLength) + private void addNode(String endpoint, Float owns, String epDns, String token, int size, boolean hasEffectiveOwns, + TableBuilder tableBuilder) { - String status, state, load, strOwns, hostID, rack, fmt; - fmt = getFormat(hasEffectiveOwns, isTokenPerNode, maxAddressLength); + String status, state, load, strOwns, hostID, rack; if (liveNodes.contains(endpoint)) status = "U"; else if (unreachableNodes.contains(endpoint)) status = "D"; else status = "?"; @@ -222,6 +203,7 @@ private void printNode(String endpoint, Float owns, String epDns, String token, else if (movingNodes.contains(endpoint)) state = "M"; else state = "N"; + String statusAndState = status.concat(state); load = loadMap.getOrDefault(endpoint, "?"); strOwns = owns != null && hasEffectiveOwns ? 
new DecimalFormat("##0.0%").format(owns) : "?"; hostID = hostIDMap.get(endpoint); @@ -235,48 +217,24 @@ private void printNode(String endpoint, Float owns, String epDns, String token, } if (isTokenPerNode) - System.out.printf(fmt, status, state, epDns, load, strOwns, hostID, token, rack); + { + tableBuilder.add(statusAndState, epDns, load, strOwns, hostID, token, rack); + } else - System.out.printf(fmt, status, state, epDns, load, size, strOwns, hostID, rack); - } - - private void printNode(String endpoint, Float owns, List tokens, boolean hasEffectiveOwns, - boolean isTokenPerNode, int maxAddressLength) - { - printNode(endpoint, owns, tokens.get(0).ipOrDns(), tokens.get(0).token, tokens.size(), hasEffectiveOwns, - isTokenPerNode, maxAddressLength); + { + tableBuilder.add(statusAndState, epDns, load, String.valueOf(size), strOwns, hostID, rack); + } } - private void printNodeWithPort(String endpoint, Float owns, List tokens, boolean hasEffectiveOwns, - boolean isTokenPerNode, int maxAddressLength) + private void addNode(String endpoint, Float owns, List tokens, boolean hasEffectiveOwns, + TableBuilder tableBuilder) { - printNode(endpoint, owns, tokens.get(0).ipOrDns(), tokens.get(0).token, tokens.size(), hasEffectiveOwns, - isTokenPerNode, maxAddressLength); + addNode(endpoint, owns, tokens.get(0).ipOrDns(), tokens.get(0).token, tokens.size(), hasEffectiveOwns, tableBuilder); } - private String getFormat(boolean hasEffectiveOwns, boolean isTokenPerNode, int maxAddressLength) + private void addNodeWithPort(String endpoint, Float owns, List tokens, boolean hasEffectiveOwns, + TableBuilder tableBuilder) { - if (format == null) - { - StringBuilder buf = new StringBuilder(); - String addressPlaceholder = String.format("%%-%ds ", maxAddressLength); - buf.append("%s%s "); // status - buf.append(addressPlaceholder); // address - buf.append("%-9s "); // load - if (!isTokenPerNode) - buf.append("%-11s "); // "Tokens" - if (hasEffectiveOwns) - buf.append("%-16s "); // "Owns 
(effective)" - else - buf.append("%-6s "); // "Owns - buf.append("%-36s "); // Host ID - if (isTokenPerNode) - buf.append("%-39s "); // token - buf.append("%s%n"); // "Rack" - - format = buf.toString(); - } - - return format; + addNode(endpoint, owns, tokens.get(0).ipOrDns(), tokens.get(0).token, tokens.size(), hasEffectiveOwns, tableBuilder); } } diff --git a/src/java/org/apache/cassandra/tools/nodetool/TableHistograms.java b/src/java/org/apache/cassandra/tools/nodetool/TableHistograms.java index f24c8a369e64..cb3b9463df91 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/TableHistograms.java +++ b/src/java/org/apache/cassandra/tools/nodetool/TableHistograms.java @@ -23,17 +23,19 @@ import io.airlift.airline.Command; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; + import org.apache.cassandra.db.ColumnFamilyStoreMBean; import org.apache.cassandra.metrics.CassandraMetricsRegistry; import org.apache.cassandra.tools.NodeProbe; import org.apache.cassandra.tools.NodeTool.NodeToolCmd; import org.apache.cassandra.utils.EstimatedHistogram; + import org.apache.commons.lang3.ArrayUtils; @Command(name = "tablehistograms", description = "Print statistic histograms for a given table") @@ -45,40 +47,46 @@ public class TableHistograms extends NodeToolCmd @Override public void execute(NodeProbe probe) { - Map> tablesList = new HashMap<>(); + Multimap tablesList = HashMultimap.create(); + + // a > mapping for verification or as reference if none provided + Multimap allTables = HashMultimap.create(); + Iterator> tableMBeans = probe.getColumnFamilyStoreMBeanProxies(); + while (tableMBeans.hasNext()) + { + Map.Entry entry = tableMBeans.next(); + allTables.put(entry.getKey(), entry.getValue().getTableName()); + } + if (args.size() == 2) { - tablesList.put(args.get(0), new 
ArrayList(Arrays.asList(args.get(1)))); + tablesList.put(args.get(0), args.get(1)); } else if (args.size() == 1) { String[] input = args.get(0).split("\\."); checkArgument(input.length == 2, "tablehistograms requires keyspace and table name arguments"); - tablesList.put(input[0], new ArrayList(Arrays.asList(input[1]))); + tablesList.put(input[0], input[1]); } else { - // get a list of table stores - Iterator> tableMBeans = probe.getColumnFamilyStoreMBeanProxies(); - while (tableMBeans.hasNext()) + // use all tables + tablesList = allTables; + } + + // verify that all tables to list exist + for (String keyspace : tablesList.keys()) + { + for (String table : tablesList.get(keyspace)) { - Map.Entry entry = tableMBeans.next(); - String keyspaceName = entry.getKey(); - ColumnFamilyStoreMBean tableProxy = entry.getValue(); - if (!tablesList.containsKey(keyspaceName)) - { - tablesList.put(keyspaceName, new ArrayList()); - } - tablesList.get(keyspaceName).add(tableProxy.getTableName()); + if (!allTables.containsEntry(keyspace, table)) + throw new IllegalArgumentException("Unknown table " + keyspace + '.' 
+ table); } } - Iterator>> iter = tablesList.entrySet().iterator(); - while(iter.hasNext()) + for (String keyspace : tablesList.keys()) { - Map.Entry> entry = iter.next(); - String keyspace = entry.getKey(); - for (String table : entry.getValue()) + for (String table : tablesList.get(keyspace)) { // calculate percentile of row size and column count long[] estimatedPartitionSize = (long[]) probe.getColumnFamilyMetric(keyspace, table, "EstimatedPartitionSizeHistogram"); diff --git a/src/java/org/apache/cassandra/tools/nodetool/formatter/TableBuilder.java b/src/java/org/apache/cassandra/tools/nodetool/formatter/TableBuilder.java index a56e52eb35c5..bf06d99ce562 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/formatter/TableBuilder.java +++ b/src/java/org/apache/cassandra/tools/nodetool/formatter/TableBuilder.java @@ -41,8 +41,8 @@ */ public class TableBuilder { - // column delimiter char - private final char columnDelimiter; + // column delimiter + private final String columnDelimiter; private int[] maximumColumnWidth; private final List rows = new ArrayList<>(); @@ -53,6 +53,11 @@ public TableBuilder() } public TableBuilder(char columnDelimiter) + { + this(String.valueOf(columnDelimiter)); + } + + public TableBuilder(String columnDelimiter) { this.columnDelimiter = columnDelimiter; } diff --git a/src/java/org/apache/cassandra/tracing/Tracing.java b/src/java/org/apache/cassandra/tracing/Tracing.java index 55e36a4a7a7f..5891981f400d 100644 --- a/src/java/org/apache/cassandra/tracing/Tracing.java +++ b/src/java/org/apache/cassandra/tracing/Tracing.java @@ -23,13 +23,11 @@ import java.net.InetAddress; import java.nio.ByteBuffer; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import com.google.common.collect.ImmutableList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,9 +39,8 @@ import 
org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.ParamType; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.JVMStabilityInspector; import org.apache.cassandra.utils.UUIDGen; @@ -246,12 +243,11 @@ public TraceState begin(final String request, final Map paramete /** * Determines the tracing context from a message. Does NOT set the threadlocal state. * - * @param message The internode message + * @param header The internode message header */ - public TraceState initializeFromMessage(final MessageIn message) + public TraceState initializeFromMessage(final Message.Header header) { - final UUID sessionId = (UUID)message.parameters.get(ParameterType.TRACE_SESSION); - + final UUID sessionId = header.traceSession(); if (sessionId == null) return null; @@ -259,31 +255,60 @@ public TraceState initializeFromMessage(final MessageIn message) if (ts != null && ts.acquireReference()) return ts; - TraceType tmpType; - TraceType traceType = TraceType.QUERY; - if ((tmpType = (TraceType)message.parameters.get(ParameterType.TRACE_TYPE)) != null) - traceType = tmpType; + TraceType traceType = header.traceType(); - if (message.verb == MessagingService.Verb.REQUEST_RESPONSE) + if (header.verb.isResponse()) { // received a message for a session we've already closed out. 
see CASSANDRA-5668 - return new ExpiredTraceState(newTraceState(message.from, sessionId, traceType)); + return new ExpiredTraceState(newTraceState(header.from, sessionId, traceType)); } else { - ts = newTraceState(message.from, sessionId, traceType); + ts = newTraceState(header.from, sessionId, traceType); sessions.put(sessionId, ts); return ts; } } - public List getTraceHeaders() + /** + * Record any tracing data, if enabled on this message. + */ + public void traceOutgoingMessage(Message message, InetAddressAndPort sendTo) + { + try + { + final UUID sessionId = message.traceSession(); + if (sessionId == null) + return; + + String logMessage = String.format("Sending %s message to %s", message.verb(), sendTo); + + TraceState state = get(sessionId); + if (state == null) // session may have already finished; see CASSANDRA-5668 + { + TraceType traceType = message.traceType(); + trace(ByteBuffer.wrap(UUIDGen.decompose(sessionId)), logMessage, traceType.getTTL()); + } + else + { + state.trace(logMessage); + if (message.verb().isResponse()) + doneWithNonLocalSession(state); + } + } + catch (Exception e) + { + logger.warn("failed to capture the tracing info for an outbound message to {}, ignoring", sendTo, e); + } + } + + public Map addTraceHeaders(Map addToMutable) { assert isTracing(); - return ImmutableList.of( - ParameterType.TRACE_SESSION, Tracing.instance.getSessionId(), - ParameterType.TRACE_TYPE, Tracing.instance.getTraceType()); + addToMutable.put(ParamType.TRACE_SESSION, Tracing.instance.getSessionId()); + addToMutable.put(ParamType.TRACE_TYPE, Tracing.instance.getTraceType()); + return addToMutable; } protected abstract TraceState newTraceState(InetAddressAndPort coordinator, UUID sessionId, Tracing.TraceType traceType); diff --git a/src/java/org/apache/cassandra/transport/Connection.java b/src/java/org/apache/cassandra/transport/Connection.java index 908e7e964852..b7f5b170fff9 100644 --- a/src/java/org/apache/cassandra/transport/Connection.java +++ 
b/src/java/org/apache/cassandra/transport/Connection.java @@ -30,6 +30,7 @@ public class Connection private final Tracker tracker; private volatile FrameBodyTransformer transformer; + private boolean throwOnOverload; public Connection(Channel channel, ProtocolVersion version, Tracker tracker) { @@ -50,6 +51,16 @@ public FrameBodyTransformer getTransformer() return transformer; } + public void setThrowOnOverload(boolean throwOnOverload) + { + this.throwOnOverload = throwOnOverload; + } + + public boolean isThrowOnOverload() + { + return throwOnOverload; + } + public Tracker getTracker() { return tracker; diff --git a/src/java/org/apache/cassandra/transport/ConnectionLimitHandler.java b/src/java/org/apache/cassandra/transport/ConnectionLimitHandler.java index 7bcf280cd04c..3b2765f1e02e 100644 --- a/src/java/org/apache/cassandra/transport/ConnectionLimitHandler.java +++ b/src/java/org/apache/cassandra/transport/ConnectionLimitHandler.java @@ -22,6 +22,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.utils.NoSpamLogger; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,6 +31,7 @@ import java.net.InetSocketAddress; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -40,6 +43,8 @@ final class ConnectionLimitHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(ConnectionLimitHandler.class); + private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(logger, 1L, TimeUnit.MINUTES); + private final ConcurrentMap connectionsPerClient = new ConcurrentHashMap<>(); private final AtomicLong counter = new AtomicLong(0); @@ -56,7 +61,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception if (count > limit) { // The 
decrement will be done in channelClosed(...) - logger.warn("Exceeded maximum native connection limit of {} by using {} connections", limit, count); + noSpamLogger.warn("Exceeded maximum native connection limit of {} by using {} connections", limit, count); ctx.close(); } else @@ -80,7 +85,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception if (perIpCount.incrementAndGet() > perIpLimit) { // The decrement will be done in channelClosed(...) - logger.warn("Exceeded maximum native connection limit per ip of {} by using {} connections", perIpLimit, perIpCount); + noSpamLogger.warn("Exceeded maximum native connection limit per ip of {} by using {} connections", perIpLimit, perIpCount); ctx.close(); return; } diff --git a/src/java/org/apache/cassandra/transport/Frame.java b/src/java/org/apache/cassandra/transport/Frame.java index d3c810b2b924..8163d7ab1f6b 100644 --- a/src/java/org/apache/cassandra/transport/Frame.java +++ b/src/java/org/apache/cassandra/transport/Frame.java @@ -72,7 +72,7 @@ public boolean release() public static Frame create(Message.Type type, int streamId, ProtocolVersion version, EnumSet flags, ByteBuf body) { - Header header = new Header(version, flags, streamId, type); + Header header = new Header(version, flags, streamId, type, body.readableBytes()); return new Frame(header, body); } @@ -87,13 +87,15 @@ public static class Header public final EnumSet flags; public final int streamId; public final Message.Type type; + public final long bodySizeInBytes; - private Header(ProtocolVersion version, EnumSet flags, int streamId, Message.Type type) + private Header(ProtocolVersion version, EnumSet flags, int streamId, Message.Type type, long bodySizeInBytes) { this.version = version; this.flags = flags; this.streamId = streamId; this.type = type; + this.bodySizeInBytes = bodySizeInBytes; } public enum Flag @@ -227,7 +229,7 @@ Frame decodeFrame(ByteBuf buffer) idx += bodyLength; buffer.readerIndex(idx); - return new Frame(new 
Header(version, decodedFlags, streamId, type), body); + return new Frame(new Header(version, decodedFlags, streamId, type, bodyLength), body); } @Override diff --git a/src/java/org/apache/cassandra/transport/Message.java b/src/java/org/apache/cassandra/transport/Message.java index 0571478c67fd..99c012770100 100644 --- a/src/java/org/apache/cassandra/transport/Message.java +++ b/src/java/org/apache/cassandra/transport/Message.java @@ -42,6 +42,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.concurrent.LocalAwareExecutorService; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.OverloadedException; +import org.apache.cassandra.metrics.ClientMetrics; +import org.apache.cassandra.net.ResourceLimits; import org.apache.cassandra.service.ClientWarn; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.tracing.Tracing; @@ -50,6 +55,8 @@ import org.apache.cassandra.utils.JVMStabilityInspector; import org.apache.cassandra.utils.UUIDGen; +import static org.apache.cassandra.concurrent.SharedExecutorPool.SHARED; + /** * A message from the CQL binary protocol. */ @@ -452,19 +459,42 @@ public void encode(ChannelHandlerContext ctx, Message message, List results) } } - @ChannelHandler.Sharable public static class Dispatcher extends SimpleChannelInboundHandler { + private static final LocalAwareExecutorService requestExecutor = SHARED.newExecutor(DatabaseDescriptor.getNativeTransportMaxThreads(), + Integer.MAX_VALUE, + "transport", + "Native-Transport-Requests"); + + /** + * Current count of *request* bytes that are live on the channel. + * + * Note: should only be accessed while on the netty event loop. 
+ */ + private long channelPayloadBytesInFlight; + + private final Server.EndpointPayloadTracker endpointPayloadTracker; + + private boolean paused; + private static class FlushItem { final ChannelHandlerContext ctx; final Object response; final Frame sourceFrame; - private FlushItem(ChannelHandlerContext ctx, Object response, Frame sourceFrame) + final Dispatcher dispatcher; + + private FlushItem(ChannelHandlerContext ctx, Object response, Frame sourceFrame, Dispatcher dispatcher) { this.ctx = ctx; this.sourceFrame = sourceFrame; this.response = response; + this.dispatcher = dispatcher; + } + + public void release() + { + dispatcher.releaseItem(this); } } @@ -520,7 +550,7 @@ public void run() for (ChannelHandlerContext channel : channels) channel.flush(); for (FlushItem item : flushed) - item.sourceFrame.release(); + item.release(); channels.clear(); flushed.clear(); @@ -572,7 +602,7 @@ public void run() for (ChannelHandlerContext channel : channels) channel.flush(); for (FlushItem item : flushed) - item.sourceFrame.release(); + item.release(); channels.clear(); flushed.clear(); @@ -584,16 +614,98 @@ public void run() private final boolean useLegacyFlusher; - public Dispatcher(boolean useLegacyFlusher) + public Dispatcher(boolean useLegacyFlusher, Server.EndpointPayloadTracker endpointPayloadTracker) { super(false); this.useLegacyFlusher = useLegacyFlusher; + this.endpointPayloadTracker = endpointPayloadTracker; } @Override public void channelRead0(ChannelHandlerContext ctx, Request request) { + // if we decide to handle this message, process it outside of the netty event loop + if (shouldHandleRequest(ctx, request)) + requestExecutor.submit(() -> processRequest(ctx, request)); + } + + /** This check for inflight payload to potentially discard the request should have been ideally in one of the + * first handlers in the pipeline (Frame::decode()). 
However, incase of any exception thrown between that + * handler (where inflight payload is incremented) and this handler (Dispatcher::channelRead0) (where inflight + * payload in decremented), inflight payload becomes erroneous. ExceptionHandler is not sufficient for this + * purpose since it does not have the frame associated with the exception. + * + * Note: this method should execute on the netty event loop. + */ + private boolean shouldHandleRequest(ChannelHandlerContext ctx, Request request) + { + long frameSize = request.getSourceFrame().header.bodySizeInBytes; + + ResourceLimits.EndpointAndGlobal endpointAndGlobalPayloadsInFlight = endpointPayloadTracker.endpointAndGlobalPayloadsInFlight; + + // check for overloaded state by trying to allocate framesize to inflight payload trackers + if (endpointAndGlobalPayloadsInFlight.tryAllocate(frameSize) != ResourceLimits.Outcome.SUCCESS) + { + if (request.connection.isThrowOnOverload()) + { + // discard the request and throw an exception + ClientMetrics.instance.markRequestDiscarded(); + logger.trace("Discarded request of size: {}. InflightChannelRequestPayload: {}, InflightEndpointRequestPayload: {}, InflightOverallRequestPayload: {}, Request: {}", + frameSize, + channelPayloadBytesInFlight, + endpointAndGlobalPayloadsInFlight.endpoint().using(), + endpointAndGlobalPayloadsInFlight.global().using(), + request); + throw ErrorMessage.wrap(new OverloadedException("Server is in overloaded state. 
Cannot accept more requests at this point"), + request.getSourceFrame().header.streamId); + } + else + { + // set backpressure on the channel, and handle the request + endpointAndGlobalPayloadsInFlight.allocate(frameSize); + ctx.channel().config().setAutoRead(false); + ClientMetrics.instance.pauseConnection(); + paused = true; + } + } + + channelPayloadBytesInFlight += frameSize; + return true; + } + + /** + * Note: this method will be used in the {@link Flusher#run()}, which executes on the netty event loop + * ({@link Dispatcher#flusherLookup}). Thus, we assume the semantics and visibility of variables + * of being on the event loop. + */ + private void releaseItem(FlushItem item) + { + long itemSize = item.sourceFrame.header.bodySizeInBytes; + item.sourceFrame.release(); + + // since the request has been processed, decrement inflight payload at channel, endpoint and global levels + channelPayloadBytesInFlight -= itemSize; + ResourceLimits.Outcome endpointGlobalReleaseOutcome = endpointPayloadTracker.endpointAndGlobalPayloadsInFlight.release(itemSize); + + // now check to see if we need to reenable the channel's autoRead. + // If the current payload side is zero, we must reenable autoread as + // 1) we allow no other thread/channel to do it, and + // 2) there's no other events following this one (becuase we're at zero bytes in flight), + // so no successive to trigger the other clause in this if-block + ChannelConfig config = item.ctx.channel().config(); + if (paused && (channelPayloadBytesInFlight == 0 || endpointGlobalReleaseOutcome == ResourceLimits.Outcome.BELOW_LIMIT)) + { + paused = false; + ClientMetrics.instance.unpauseConnection(); + config.setAutoRead(true); + } + } + /** + * Note: this method is not expected to execute on the netty event loop. 
+ */ + void processRequest(ChannelHandlerContext ctx, Request request) + { final Response response; final ServerConnection connection; long queryStartNanoTime = System.nanoTime(); @@ -619,7 +731,7 @@ public void channelRead0(ChannelHandlerContext ctx, Request request) { JVMStabilityInspector.inspectThrowable(t); UnexpectedChannelExceptionHandler handler = new UnexpectedChannelExceptionHandler(ctx.channel(), true); - flush(new FlushItem(ctx, ErrorMessage.fromException(t, handler).setStreamId(request.getStreamId()), request.getSourceFrame())); + flush(new FlushItem(ctx, ErrorMessage.fromException(t, handler).setStreamId(request.getStreamId()), request.getSourceFrame(), this)); return; } finally @@ -628,7 +740,19 @@ public void channelRead0(ChannelHandlerContext ctx, Request request) } logger.trace("Responding: {}, v={}", response, connection.getVersion()); - flush(new FlushItem(ctx, response, request.getSourceFrame())); + flush(new FlushItem(ctx, response, request.getSourceFrame(), this)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + endpointPayloadTracker.release(); + if (paused) + { + paused = false; + ClientMetrics.instance.unpauseConnection(); + } + ctx.fireChannelInactive(); } private void flush(FlushItem item) @@ -646,6 +770,14 @@ private void flush(FlushItem item) flusher.queued.add(item); flusher.start(); } + + public static void shutdown() + { + if (requestExecutor != null) + { + requestExecutor.shutdown(); + } + } } @ChannelHandler.Sharable diff --git a/src/java/org/apache/cassandra/transport/ProtocolVersion.java b/src/java/org/apache/cassandra/transport/ProtocolVersion.java index 546983ffafd5..5c8c299952a8 100644 --- a/src/java/org/apache/cassandra/transport/ProtocolVersion.java +++ b/src/java/org/apache/cassandra/transport/ProtocolVersion.java @@ -104,13 +104,13 @@ public static ProtocolVersion decode(int versionNum, boolean allowOlderProtocols // if this is not a supported version check the old versions for 
(ProtocolVersion version : UNSUPPORTED) { - // if it is an old version that is no longer supported this ensures that we reply + // if it is an old version that is no longer supported this ensures that we respond // with that same version if (version.num == versionNum) throw new ProtocolException(ProtocolVersion.invalidVersionMessage(versionNum), version); } - // If the version is invalid reply with the highest version that we support + // If the version is invalid response with the highest version that we support throw new ProtocolException(invalidVersionMessage(versionNum), MAX_SUPPORTED_VERSION); } diff --git a/src/java/org/apache/cassandra/transport/ProtocolVersionTracker.java b/src/java/org/apache/cassandra/transport/ProtocolVersionTracker.java index 72bb9012d9cc..f2893778c70d 100644 --- a/src/java/org/apache/cassandra/transport/ProtocolVersionTracker.java +++ b/src/java/org/apache/cassandra/transport/ProtocolVersionTracker.java @@ -25,7 +25,6 @@ import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; import com.github.benmanes.caffeine.cache.LoadingCache; -import org.apache.cassandra.utils.Clock; /** * This class tracks the last 100 connections per protocol version @@ -48,13 +47,13 @@ private ProtocolVersionTracker(int capacity) for (ProtocolVersion version : ProtocolVersion.values()) { clientsByProtocolVersion.put(version, Caffeine.newBuilder().maximumSize(capacity) - .build(key -> Clock.instance.currentTimeMillis())); + .build(key -> System.currentTimeMillis())); } } void addConnection(InetAddress addr, ProtocolVersion version) { - clientsByProtocolVersion.get(version).put(addr, Clock.instance.currentTimeMillis()); + clientsByProtocolVersion.get(version).put(addr, System.currentTimeMillis()); } List getAll() diff --git a/src/java/org/apache/cassandra/transport/RequestThreadPoolExecutor.java b/src/java/org/apache/cassandra/transport/RequestThreadPoolExecutor.java deleted file mode 100644 index 
75dd05ddf878..000000000000 --- a/src/java/org/apache/cassandra/transport/RequestThreadPoolExecutor.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.transport; - -import java.util.List; -import java.util.concurrent.TimeUnit; - -import io.netty.util.concurrent.AbstractEventExecutor; -import io.netty.util.concurrent.EventExecutorGroup; -import io.netty.util.concurrent.Future; -import org.apache.cassandra.concurrent.LocalAwareExecutorService; -import org.apache.cassandra.config.DatabaseDescriptor; - -import static org.apache.cassandra.concurrent.SharedExecutorPool.SHARED; - -public class RequestThreadPoolExecutor extends AbstractEventExecutor -{ - private final static int MAX_QUEUED_REQUESTS = Integer.getInteger("cassandra.max_queued_native_transport_requests", 128); - private final static String THREAD_FACTORY_ID = "Native-Transport-Requests"; - private final LocalAwareExecutorService wrapped = SHARED.newExecutor(DatabaseDescriptor.getNativeTransportMaxThreads(), - MAX_QUEUED_REQUESTS, - "transport", - THREAD_FACTORY_ID); - - public boolean isShuttingDown() - { - return wrapped.isShutdown(); - } - - public Future shutdownGracefully(long l, long l2, 
TimeUnit timeUnit) - { - throw new IllegalStateException(); - } - - public Future terminationFuture() - { - throw new IllegalStateException(); - } - - @Override - public void shutdown() - { - wrapped.shutdown(); - } - - @Override - public List shutdownNow() - { - return wrapped.shutdownNow(); - } - - public boolean isShutdown() - { - return wrapped.isShutdown(); - } - - public boolean isTerminated() - { - return wrapped.isTerminated(); - } - - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException - { - return wrapped.awaitTermination(timeout, unit); - } - - public EventExecutorGroup parent() - { - return null; - } - - public boolean inEventLoop(Thread thread) - { - return false; - } - - public void execute(Runnable command) - { - wrapped.execute(command); - } -} diff --git a/src/java/org/apache/cassandra/transport/Server.java b/src/java/org/apache/cassandra/transport/Server.java index 33cd0fb3ae2c..c4690f157872 100644 --- a/src/java/org/apache/cassandra/transport/Server.java +++ b/src/java/org/apache/cassandra/transport/Server.java @@ -23,7 +23,10 @@ import java.net.UnknownHostException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,6 +44,9 @@ import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; import io.netty.util.Version; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.GlobalEventExecutor; @@ -51,6 +57,7 @@ import org.apache.cassandra.config.EncryptionOptions; import org.apache.cassandra.db.marshal.AbstractType; 
import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.ResourceLimits; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaChangeListener; import org.apache.cassandra.security.SSLFactory; @@ -83,7 +90,6 @@ public Connection newConnection(Channel channel, ProtocolVersion version) private final AtomicBoolean isRunning = new AtomicBoolean(false); private EventLoopGroup workerGroup; - private EventExecutor eventExecutorGroup; private Server (Builder builder) { @@ -100,8 +106,6 @@ private Server (Builder builder) else workerGroup = new NioEventLoopGroup(); } - if (builder.eventExecutorGroup != null) - eventExecutorGroup = builder.eventExecutorGroup; EventNotifier notifier = new EventNotifier(this); StorageService.instance.register(notifier); Schema.instance.registerListener(notifier); @@ -230,12 +234,6 @@ public Builder withEventLoopGroup(EventLoopGroup eventLoopGroup) return this; } - public Builder withEventExecutor(EventExecutor eventExecutor) - { - this.eventExecutorGroup = eventExecutor; - return this; - } - public Builder withHost(InetAddress host) { this.hostAddr = host; @@ -337,6 +335,49 @@ Map countConnectedClientsByUser() } + // global inflight payload across all channels across all endpoints + private static final ResourceLimits.Concurrent globalRequestPayloadInFlight = new ResourceLimits.Concurrent(DatabaseDescriptor.getNativeTransportMaxConcurrentRequestsInBytes()); + + public static class EndpointPayloadTracker + { + // inflight payload per endpoint across corresponding channels + private static final ConcurrentMap requestPayloadInFlightPerEndpoint = new ConcurrentHashMap<>(); + + private final AtomicInteger refCount = new AtomicInteger(0); + private final InetAddress endpoint; + + final ResourceLimits.EndpointAndGlobal endpointAndGlobalPayloadsInFlight = new ResourceLimits.EndpointAndGlobal(new 
ResourceLimits.Concurrent(DatabaseDescriptor.getNativeTransportMaxConcurrentRequestsInBytesPerIp()), + globalRequestPayloadInFlight); + + private EndpointPayloadTracker(InetAddress endpoint) + { + this.endpoint = endpoint; + } + + public static EndpointPayloadTracker get(InetAddress endpoint) + { + while (true) + { + EndpointPayloadTracker result = requestPayloadInFlightPerEndpoint.computeIfAbsent(endpoint, EndpointPayloadTracker::new); + if (result.acquire()) + return result; + + requestPayloadInFlightPerEndpoint.remove(endpoint, result); + } + } + + private boolean acquire() + { + return 0 < refCount.updateAndGet(i -> i < 0 ? i : i + 1); + } + + public void release() + { + if (-1 == refCount.updateAndGet(i -> i == 1 ? -1 : i - 1)) + requestPayloadInFlightPerEndpoint.remove(endpoint, this); + } + } + private static class Initializer extends ChannelInitializer { // Stateless handlers @@ -346,7 +387,6 @@ private static class Initializer extends ChannelInitializer private static final Frame.OutboundBodyTransformer outboundFrameTransformer = new Frame.OutboundBodyTransformer(); private static final Frame.Encoder frameEncoder = new Frame.Encoder(); private static final Message.ExceptionHandler exceptionHandler = new Message.ExceptionHandler(); - private static final Message.Dispatcher dispatcher = new Message.Dispatcher(DatabaseDescriptor.useNativeTransportLegacyFlusher()); private static final ConnectionLimitHandler connectionLimitHandler = new ConnectionLimitHandler(); private final Server server; @@ -368,6 +408,20 @@ protected void initChannel(Channel channel) throws Exception pipeline.addFirst("connectionLimitHandler", connectionLimitHandler); } + long idleTimeout = DatabaseDescriptor.nativeTransportIdleTimeout(); + if (idleTimeout > 0) + { + pipeline.addLast("idleStateHandler", new IdleStateHandler(false, 0, 0, idleTimeout, TimeUnit.MILLISECONDS) + { + @Override + protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) + { + logger.info("Closing 
client connection {} after timeout of {}ms", channel.remoteAddress(), idleTimeout); + ctx.close(); + } + }); + } + //pipeline.addLast("debug", new LoggingHandler()); pipeline.addLast("frameDecoder", new Frame.Decoder(server.connectionFactory)); @@ -379,6 +433,9 @@ protected void initChannel(Channel channel) throws Exception pipeline.addLast("messageDecoder", messageDecoder); pipeline.addLast("messageEncoder", messageEncoder); + pipeline.addLast("executor", new Message.Dispatcher(DatabaseDescriptor.useNativeTransportLegacyFlusher(), + EndpointPayloadTracker.get(((InetSocketAddress) channel.remoteAddress()).getAddress()))); + // The exceptionHandler will take care of handling exceptionCaught(...) events while still running // on the same EventLoop as all previous added handlers in the pipeline. This is important as the used // eventExecutorGroup may not enforce strict ordering for channel events. @@ -386,11 +443,6 @@ protected void initChannel(Channel channel) throws Exception // correctly handled before the handler itself is removed. 
// See https://issues.apache.org/jira/browse/CASSANDRA-13649 pipeline.addLast("exceptionHandler", exceptionHandler); - - if (server.eventExecutorGroup != null) - pipeline.addLast(server.eventExecutorGroup, "executor", dispatcher); - else - pipeline.addLast("executor", dispatcher); } } diff --git a/src/java/org/apache/cassandra/transport/SimpleClient.java b/src/java/org/apache/cassandra/transport/SimpleClient.java index deba207d6afb..6340b69cb5a2 100644 --- a/src/java/org/apache/cassandra/transport/SimpleClient.java +++ b/src/java/org/apache/cassandra/transport/SimpleClient.java @@ -123,11 +123,19 @@ public SimpleClient(String host, int port) } public SimpleClient connect(boolean useCompression, boolean useChecksums) throws IOException + { + return connect(useCompression, useChecksums, false); + } + + public SimpleClient connect(boolean useCompression, boolean useChecksums, boolean throwOnOverload) throws IOException { establishConnection(); Map options = new HashMap<>(); options.put(StartupMessage.CQL_VERSION, "3.0.0"); + if (throwOnOverload) + options.put(StartupMessage.THROW_ON_OVERLOAD, "1"); + connection.setThrowOnOverload(throwOnOverload); if (useChecksums) { diff --git a/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java b/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java index 8e0b19ff64bc..a358015b3510 100644 --- a/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java +++ b/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java @@ -377,7 +377,7 @@ else if (e instanceof WrappedException) if (e instanceof ProtocolException) { // if the driver attempted to connect with a protocol version not supported then - // reply with the appropiate version, see ProtocolVersion.decode() + // respond with the appropiate version, see ProtocolVersion.decode() ProtocolVersion forcedProtocolVersion = ((ProtocolException) e).getForcedProtocolVersion(); if (forcedProtocolVersion != null) message.forcedProtocolVersion = 
forcedProtocolVersion; diff --git a/src/java/org/apache/cassandra/transport/messages/StartupMessage.java b/src/java/org/apache/cassandra/transport/messages/StartupMessage.java index ef846c1f9a87..ee2b34e8ae08 100644 --- a/src/java/org/apache/cassandra/transport/messages/StartupMessage.java +++ b/src/java/org/apache/cassandra/transport/messages/StartupMessage.java @@ -46,6 +46,7 @@ public class StartupMessage extends Message.Request public static final String DRIVER_NAME = "DRIVER_NAME"; public static final String DRIVER_VERSION = "DRIVER_VERSION"; public static final String CHECKSUM = "CONTENT_CHECKSUM"; + public static final String THROW_ON_OVERLOAD = "THROW_ON_OVERLOAD"; public static final Message.Codec codec = new Message.Codec() { @@ -104,6 +105,8 @@ else if (null != compressor) connection.setTransformer(CompressingTransformer.getTransformer(compressor)); } + connection.setThrowOnOverload("1".equals(options.get(THROW_ON_OVERLOAD))); + ClientState clientState = state.getClientState(); String driverName = options.get(DRIVER_NAME); if (null != driverName) diff --git a/src/java/org/apache/cassandra/utils/ApproximateTime.java b/src/java/org/apache/cassandra/utils/ApproximateTime.java new file mode 100644 index 000000000000..32b6e44317ca --- /dev/null +++ b/src/java/org/apache/cassandra/utils/ApproximateTime.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.utils; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.Config; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.cassandra.utils.ApproximateTime.Measurement.ALMOST_NOW; +import static org.apache.cassandra.utils.ApproximateTime.Measurement.ALMOST_SAME_TIME; + +/** + * This class provides approximate time utilities: + * - An imprecise nanoTime (monotonic) and currentTimeMillis (non-monotonic), that are faster than their regular counterparts + * They have a configured approximate precision (default of 10ms), which is the cadence they will be updated if the system is healthy + * - A mechanism for converting between nanoTime and currentTimeMillis measurements. 
+ * These conversions may have drifted, and they offer no absolute guarantees on precision + */ +public class ApproximateTime +{ + private static final Logger logger = LoggerFactory.getLogger(ApproximateTime.class); + private static final int ALMOST_NOW_UPDATE_INTERVAL_MS = Math.max(1, Integer.parseInt(System.getProperty(Config.PROPERTY_PREFIX + "approximate_time_precision_ms", "2"))); + private static final String CONVERSION_UPDATE_INTERVAL_PROPERTY = Config.PROPERTY_PREFIX + "NANOTIMETOMILLIS_TIMESTAMP_UPDATE_INTERVAL"; + private static final long ALMOST_SAME_TIME_UPDATE_INTERVAL_MS = Long.getLong(CONVERSION_UPDATE_INTERVAL_PROPERTY, 10000); + + public static class AlmostSameTime + { + final long millis; + final long nanos; + final long error; // maximum error of millis measurement (in nanos) + + private AlmostSameTime(long millis, long nanos, long error) + { + this.millis = millis; + this.nanos = nanos; + this.error = error; + } + + public long toCurrentTimeMillis(long nanoTime) + { + return millis + TimeUnit.NANOSECONDS.toMillis(nanoTime - nanos); + } + + public long toNanoTime(long currentTimeMillis) + { + return nanos + MILLISECONDS.toNanos(currentTimeMillis - millis); + } + } + + public enum Measurement { ALMOST_NOW, ALMOST_SAME_TIME } + + private static volatile Future almostNowUpdater; + private static volatile Future almostSameTimeUpdater; + + private static volatile long almostNowMillis; + private static volatile long almostNowNanos; + + private static volatile AlmostSameTime almostSameTime = new AlmostSameTime(0L, 0L, Long.MAX_VALUE); + private static double failedAlmostSameTimeUpdateModifier = 1.0; + + private static final Runnable refreshAlmostNow = () -> { + almostNowMillis = System.currentTimeMillis(); + almostNowNanos = System.nanoTime(); + }; + + private static final Runnable refreshAlmostSameTime = () -> { + final int tries = 3; + long[] samples = new long[2 * tries + 1]; + samples[0] = System.nanoTime(); + for (int i = 1 ; i < samples.length ; i 
+= 2) + { + samples[i] = System.currentTimeMillis(); + samples[i + 1] = System.nanoTime(); + } + + int best = 1; + // take sample with minimum delta between calls + for (int i = 3 ; i < samples.length - 1 ; i += 2) + { + if ((samples[i+1] - samples[i-1]) < (samples[best+1]-samples[best-1])) + best = i; + } + + long millis = samples[best]; + long nanos = (samples[best+1] / 2) + (samples[best-1] / 2); + long error = (samples[best+1] / 2) - (samples[best-1] / 2); + + AlmostSameTime prev = almostSameTime; + AlmostSameTime next = new AlmostSameTime(millis, nanos, error); + + if (next.error > prev.error && next.error > prev.error * failedAlmostSameTimeUpdateModifier) + { + failedAlmostSameTimeUpdateModifier *= 1.1; + return; + } + + failedAlmostSameTimeUpdateModifier = 1.0; + almostSameTime = next; + }; + + static + { + start(ALMOST_NOW); + start(ALMOST_SAME_TIME); + } + + public static synchronized void stop(Measurement measurement) + { + switch (measurement) + { + case ALMOST_NOW: + almostNowUpdater.cancel(true); + try { almostNowUpdater.get(); } catch (Throwable t) { } + almostNowUpdater = null; + break; + case ALMOST_SAME_TIME: + almostSameTimeUpdater.cancel(true); + try { almostSameTimeUpdater.get(); } catch (Throwable t) { } + almostSameTimeUpdater = null; + break; + } + } + + public static synchronized void start(Measurement measurement) + { + switch (measurement) + { + case ALMOST_NOW: + if (almostNowUpdater != null) + throw new IllegalStateException("Already running"); + refreshAlmostNow.run(); + logger.info("Scheduling approximate time-check task with a precision of {} milliseconds", ALMOST_NOW_UPDATE_INTERVAL_MS); + almostNowUpdater = ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(refreshAlmostNow, ALMOST_NOW_UPDATE_INTERVAL_MS, ALMOST_NOW_UPDATE_INTERVAL_MS, MILLISECONDS); + break; + case ALMOST_SAME_TIME: + if (almostSameTimeUpdater != null) + throw new IllegalStateException("Already running"); + refreshAlmostSameTime.run(); + 
logger.info("Scheduling approximate time conversion task with an interval of {} milliseconds", ALMOST_SAME_TIME_UPDATE_INTERVAL_MS); + almostSameTimeUpdater = ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(refreshAlmostSameTime, ALMOST_SAME_TIME_UPDATE_INTERVAL_MS, ALMOST_SAME_TIME_UPDATE_INTERVAL_MS, MILLISECONDS); + break; + } + } + + + /** + * Request an immediate refresh; this shouldn't generally be invoked, except perhaps by tests + */ + @VisibleForTesting + public static synchronized void refresh(Measurement measurement) + { + stop(measurement); + start(measurement); + } + + /** no guarantees about relationship to nanoTime; non-monotonic (tracks currentTimeMillis as closely as possible) */ + public static long currentTimeMillis() + { + return almostNowMillis; + } + + /** no guarantees about relationship to currentTimeMillis; monotonic */ + public static long nanoTime() + { + return almostNowNanos; + } +} diff --git a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java index d6c9e52b7793..518436ea2527 100644 --- a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java +++ b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java @@ -101,6 +101,16 @@ public static int compare(ByteBuffer o1, byte[] o2) return FastByteOperations.compareUnsigned(o1, o2, 0, o2.length); } + public static int compare(ByteBuffer o1, int s1, int l1, byte[] o2) + { + return FastByteOperations.compareUnsigned(o1, s1, l1, o2, 0, o2.length); + } + + public static int compare(byte[] o1, ByteBuffer o2, int s2, int l2) + { + return FastByteOperations.compareUnsigned(o1, 0, o1.length, o2, s2, l2); + } + /** * Decode a String representation. * This method assumes that the encoding charset is UTF_8. 
@@ -161,16 +171,25 @@ public static String string(ByteBuffer buffer, Charset charset) throws Character */ public static byte[] getArray(ByteBuffer buffer) { - int length = buffer.remaining(); + return getArray(buffer, buffer.position(), buffer.remaining()); + } + + /** + * You should almost never use this. Instead, use the write* methods to avoid copies. + */ + public static byte[] getArray(ByteBuffer buffer, int position, int length) + { if (buffer.hasArray()) { - int boff = buffer.arrayOffset() + buffer.position(); + int boff = buffer.arrayOffset() + position; return Arrays.copyOfRange(buffer.array(), boff, boff + length); } + // else, DirectByteBuffer.get() is the fastest route byte[] bytes = new byte[length]; - buffer.duplicate().get(bytes); - + ByteBuffer dup = buffer.duplicate(); + dup.position(position).limit(position + length); + dup.get(bytes); return bytes; } @@ -255,14 +274,14 @@ public static ByteBuffer clone(ByteBuffer buffer) return clone; } - public static void arrayCopy(ByteBuffer src, int srcPos, byte[] dst, int dstPos, int length) + public static void copyBytes(ByteBuffer src, int srcPos, byte[] dst, int dstPos, int length) { FastByteOperations.copy(src, srcPos, dst, dstPos, length); } /** * Transfer bytes from one ByteBuffer to another. - * This function acts as System.arrayCopy() but for ByteBuffers. + * This function acts as System.arrayCopy() but for ByteBuffers, and operates safely on direct memory. 
* * @param src the source ByteBuffer * @param srcPos starting position in the source ByteBuffer @@ -270,7 +289,7 @@ public static void arrayCopy(ByteBuffer src, int srcPos, byte[] dst, int dstPos, * @param dstPos starting position in the destination ByteBuffer * @param length the number of bytes to copy */ - public static void arrayCopy(ByteBuffer src, int srcPos, ByteBuffer dst, int dstPos, int length) + public static void copyBytes(ByteBuffer src, int srcPos, ByteBuffer dst, int dstPos, int length) { FastByteOperations.copy(src, srcPos, dst, dstPos, length); } @@ -278,7 +297,7 @@ public static void arrayCopy(ByteBuffer src, int srcPos, ByteBuffer dst, int dst public static int put(ByteBuffer src, ByteBuffer trg) { int length = Math.min(src.remaining(), trg.remaining()); - arrayCopy(src, src.position(), trg, trg.position(), length); + copyBytes(src, src.position(), trg, trg.position(), length); trg.position(trg.position() + length); src.position(src.position() + length); return length; @@ -631,6 +650,7 @@ public static int compareSubArrays(ByteBuffer bytes1, int offset1, ByteBuffer by assert bytes1.limit() >= offset1 + length : "The first byte array isn't long enough for the specified offset and length."; assert bytes2.limit() >= offset2 + length : "The second byte array isn't long enough for the specified offset and length."; + for (int i = 0; i < length; i++) { byte byte1 = bytes1.get(offset1 + i); @@ -669,7 +689,7 @@ public static ByteBuffer minimalBufferFor(ByteBuffer buf) return buf.capacity() > buf.remaining() || !buf.hasArray() ? 
ByteBuffer.wrap(getArray(buf)) : buf; } - // Doesn't change bb position + // doesn't change bb position public static int getShortLength(ByteBuffer bb, int position) { int length = (bb.get(position) & 0xFF) << 8; diff --git a/src/java/org/apache/cassandra/utils/Clock.java b/src/java/org/apache/cassandra/utils/Clock.java deleted file mode 100644 index eb9822c1cff9..000000000000 --- a/src/java/org/apache/cassandra/utils/Clock.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.utils; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Wrapper around time related functions that are either implemented by using the default JVM calls - * or by using a custom implementation for testing purposes. - * - * See {@link #instance} for how to use a custom implementation. - * - * Please note that {@link java.time.Clock} wasn't used, as it would not be possible to provide an - * implementation for {@link #nanoTime()} with the exact same properties of {@link System#nanoTime()}. 
- */ -public class Clock -{ - private static final Logger logger = LoggerFactory.getLogger(Clock.class); - - /** - * Static singleton object that will be instanciated by default with a system clock - * implementation. Set cassandra.clock system property to a FQCN to use a - * different implementation instead. - */ - public static Clock instance; - - static - { - String sclock = System.getProperty("cassandra.clock"); - if (sclock == null) - { - instance = new Clock(); - } - else - { - try - { - logger.debug("Using custom clock implementation: {}", sclock); - instance = (Clock) Class.forName(sclock).newInstance(); - } - catch (Exception e) - { - logger.error(e.getMessage(), e); - } - } - } - - /** - * @see System#nanoTime() - */ - public long nanoTime() - { - return System.nanoTime(); - } - - /** - * @see System#currentTimeMillis() - */ - public long currentTimeMillis() - { - return System.currentTimeMillis(); - } - -} diff --git a/src/java/org/apache/cassandra/utils/CoalescingStrategies.java b/src/java/org/apache/cassandra/utils/CoalescingStrategies.java deleted file mode 100644 index 2f9e5bbf084d..000000000000 --- a/src/java/org/apache/cassandra/utils/CoalescingStrategies.java +++ /dev/null @@ -1,444 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.utils; - -import org.apache.cassandra.concurrent.ScheduledExecutors; -import org.apache.cassandra.config.Config; -import org.apache.cassandra.io.util.FileUtils; -import org.slf4j.Logger; - -import java.io.File; -import java.io.RandomAccessFile; -import java.lang.reflect.Constructor; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel.MapMode; -import java.util.Arrays; -import java.util.Collection; -import java.util.Optional; -import java.util.concurrent.TimeUnit; -import java.util.Locale; - -import com.google.common.annotations.VisibleForTesting; - -/** - * Groups strategies to coalesce messages. - */ -public class CoalescingStrategies -{ - /* - * Log debug information at info level about what the average is and when coalescing is enabled/disabled - */ - private static final String DEBUG_COALESCING_PROPERTY = Config.PROPERTY_PREFIX + "coalescing_debug"; - private static final boolean DEBUG_COALESCING = Boolean.getBoolean(DEBUG_COALESCING_PROPERTY); - - private static final String DEBUG_COALESCING_PATH_PROPERTY = Config.PROPERTY_PREFIX + "coalescing_debug_path"; - private static final String DEBUG_COALESCING_PATH = System.getProperty(DEBUG_COALESCING_PATH_PROPERTY, "/tmp/coleascing_debug"); - - public enum Strategy { MOVINGAVERAGE, FIXED, TIMEHORIZON, DISABLED } - - static - { - if (DEBUG_COALESCING) - { - File directory = new File(DEBUG_COALESCING_PATH); - - if (directory.exists()) - FileUtils.deleteRecursive(directory); - - if (!directory.mkdirs()) - throw new ExceptionInInitializerError("Couldn't create log dir"); - } - } - - public static interface Coalescable - { - long timestampNanos(); - } - - @VisibleForTesting - static long determineCoalescingTime(long averageGap, long maxCoalesceWindow) - { - // Don't bother waiting at all if we're unlikely to get any new message within our max window - if (averageGap 
> maxCoalesceWindow) - return -1; - - // avoid the degenerate case of zero (very unlikely, but let's be safe) - if (averageGap <= 0) - return maxCoalesceWindow; - - // assume we receive as many messages as we expect; apply the same logic to the future batch: - // expect twice as many messages to consider sleeping for "another" interval; this basically translates - // to doubling our sleep period until we exceed our max sleep window. - long sleep = averageGap; - while (sleep * 2 < maxCoalesceWindow) - sleep *= 2; - return sleep; - } - - /** - * A coalescing strategy, that decides when to coalesce messages. - *

- * The general principle is that, when asked, the strategy returns the time delay we want to wait for more messages - * to arrive before sending so message can be coalesced. For that, the strategy must be fed new messages through - * the {@link #newArrival(Coalescable)} method (the only assumption we make on messages is that they have an associated - * timestamp). The strategy can then be queried for the time to wait for coalescing through - * {@link #currentCoalescingTimeNanos()}. - *

- * Note that it is expected that a call {@link #currentCoalescingTimeNanos()} will come just after a call to - * {@link #newArrival(Coalescable))}, as the intent of the value returned by the former method is "Given a new message, how much - * time should I wait for more messages to arrive and be coalesced with that message". But both calls are separated - * as one may not want to call {@link #currentCoalescingTimeNanos()} after every call to {@link #newArrival(Coalescable)} - * and we thus save processing. How arrivals influence the coalescing time is however entirely up to the strategy and some - * strategy may ignore arrivals completely and return a constant coalescing time. - */ - public interface CoalescingStrategy - { - /** - * Inform the strategy of a new message to consider. - * - * @param message the message to consider. - */ - void newArrival(Coalescable message); - - /** - * The current time to wait for the purpose of coalescing messages. - * - * @return the coalescing time. A negative value can be returned if no coalescing should be done (which can be a - * transient thing). 
- */ - long currentCoalescingTimeNanos(); - } - - public static abstract class AbstractCoalescingStrategy implements CoalescingStrategy - { - protected final Logger logger; - protected volatile boolean shouldLogAverage = false; - protected final ByteBuffer logBuffer; - private RandomAccessFile ras; - private final String displayName; - - protected AbstractCoalescingStrategy(Logger logger, String displayName) - { - this.logger = logger; - this.displayName = displayName; - - RandomAccessFile rasTemp = null; - ByteBuffer logBufferTemp = null; - if (DEBUG_COALESCING) - { - ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(() -> shouldLogAverage = true, 5, 5, TimeUnit.SECONDS); - try - { - File outFile = FileUtils.createTempFile("coalescing_" + this.displayName + "_", ".log", new File(DEBUG_COALESCING_PATH)); - rasTemp = new RandomAccessFile(outFile, "rw"); - logBufferTemp = ras.getChannel().map(MapMode.READ_WRITE, 0, Integer.MAX_VALUE); - logBufferTemp.putLong(0); - } - catch (Exception e) - { - logger.error("Unable to create output file for debugging coalescing", e); - } - } - ras = rasTemp; - logBuffer = logBufferTemp; - } - - /* - * If debugging is enabled log to the logger the current average gap calculation result. - */ - final protected void debugGap(long averageGap) - { - if (DEBUG_COALESCING && shouldLogAverage) - { - shouldLogAverage = false; - logger.info("{} gap {}μs", this, TimeUnit.NANOSECONDS.toMicros(averageGap)); - } - } - - /* - * If debugging is enabled log the provided nanotime timestamp to a file. - */ - final protected void debugTimestamp(long timestamp) - { - if(DEBUG_COALESCING && logBuffer != null) - { - logBuffer.putLong(0, logBuffer.getLong(0) + 1); - logBuffer.putLong(timestamp); - } - } - - /* - * If debugging is enabled log the timestamps of all the items in the provided collection - * to a file. 
- */ - final protected void debugTimestamps(Collection coalescables) - { - if (DEBUG_COALESCING) - { - for (C coalescable : coalescables) - { - debugTimestamp(coalescable.timestampNanos()); - } - } - } - } - - @VisibleForTesting - static class TimeHorizonMovingAverageCoalescingStrategy extends AbstractCoalescingStrategy - { - // for now we'll just use 64ms per bucket; this can be made configurable, but results in ~1s for 16 samples - private static final int INDEX_SHIFT = 26; - private static final long BUCKET_INTERVAL = 1L << 26; - private static final int BUCKET_COUNT = 16; - private static final long INTERVAL = BUCKET_INTERVAL * BUCKET_COUNT; - private static final long MEASURED_INTERVAL = BUCKET_INTERVAL * (BUCKET_COUNT - 1); - - // the minimum timestamp we will now accept updates for; only moves forwards, never backwards - private long epoch; - // the buckets, each following on from epoch; the measurements run from ix(epoch) to ix(epoch - 1) - // ix(epoch-1) is a partial result, that is never actually part of the calculation, and most updates - // are expected to hit this bucket - private final int samples[] = new int[BUCKET_COUNT]; - private long sum = 0; - private final long maxCoalesceWindow; - - public TimeHorizonMovingAverageCoalescingStrategy(int maxCoalesceWindow, Logger logger, String displayName, long initialEpoch) - { - super(logger, displayName); - this.maxCoalesceWindow = TimeUnit.MICROSECONDS.toNanos(maxCoalesceWindow); - sum = 0; - epoch = initialEpoch; - } - - private long averageGap() - { - if (sum == 0) - return Integer.MAX_VALUE; - return MEASURED_INTERVAL / sum; - } - - // this sample extends past the end of the range we cover, so rollover - private long rollEpoch(long delta, long epoch, long nanos) - { - if (delta > 2 * INTERVAL) - { - // this sample is more than twice our interval ahead, so just clear our counters completely - epoch = epoch(nanos); - sum = 0; - Arrays.fill(samples, 0); - } - else - { - // ix(epoch - 1) => last index; this 
is our partial result bucket, so we add this to the sum - sum += samples[ix(epoch - 1)]; - // then we roll forwards, clearing buckets, until our interval covers the new sample time - while (epoch + INTERVAL < nanos) - { - int index = ix(epoch); - sum -= samples[index]; - samples[index] = 0; - epoch += BUCKET_INTERVAL; - } - } - // store the new epoch - this.epoch = epoch; - return epoch; - } - - private long epoch(long latestNanos) - { - return (latestNanos - MEASURED_INTERVAL) & ~(BUCKET_INTERVAL - 1); - } - - private int ix(long nanos) - { - return (int) ((nanos >>> INDEX_SHIFT) & 15); - } - - public void newArrival(Coalescable message) - { - final long timestamp = message.timestampNanos(); - debugTimestamp(timestamp); - long epoch = this.epoch; - long delta = timestamp - epoch; - if (delta < 0) - // have to simply ignore, but would be a bit unlucky to get such reordering - return; - - if (delta > INTERVAL) - epoch = rollEpoch(delta, epoch, timestamp); - - int ix = ix(timestamp); - samples[ix]++; - - // if we've updated an old bucket, we need to update the sum to match - if (ix != ix(epoch - 1)) - sum++; - } - - public long currentCoalescingTimeNanos() - { - long averageGap = averageGap(); - debugGap(averageGap); - return determineCoalescingTime(averageGap, maxCoalesceWindow); - } - - @Override - public String toString() - { - return "Time horizon moving average"; - } - } - - /** - * Start coalescing by sleeping if the moving average is < the requested window. - * The actual time spent waiting to coalesce will be the min( window, moving average * 2) - * The actual amount of time spent waiting can be greater then the window. For instance - * observed time spent coalescing was 400 microseconds with the window set to 200 in one benchmark. 
- */ - @VisibleForTesting - static class MovingAverageCoalescingStrategy extends AbstractCoalescingStrategy - { - static final int SAMPLE_SIZE = 16; - private final int samples[] = new int[SAMPLE_SIZE]; - private final long maxCoalesceWindow; - - private long lastSample = 0; - private int index = 0; - private long sum = 0; - private long currentGap; - - public MovingAverageCoalescingStrategy(int maxCoalesceWindow, Logger logger, String displayName) - { - super(logger, displayName); - this.maxCoalesceWindow = TimeUnit.MICROSECONDS.toNanos(maxCoalesceWindow); - for (int ii = 0; ii < samples.length; ii++) - samples[ii] = Integer.MAX_VALUE; - sum = Integer.MAX_VALUE * (long)samples.length; - } - - private long logSample(int value) - { - sum -= samples[index]; - sum += value; - samples[index] = value; - index++; - index = index & ((1 << 4) - 1); - return sum / SAMPLE_SIZE; - } - - public void newArrival(Coalescable message) - { - final long timestamp = message.timestampNanos(); - debugTimestamp(timestamp); - if (timestamp > lastSample) - { - final int delta = (int)(Math.min(Integer.MAX_VALUE, timestamp - lastSample)); - lastSample = timestamp; - currentGap = logSample(delta); - } - else - { - currentGap = logSample(1); - } - } - - public long currentCoalescingTimeNanos() - { - debugGap(currentGap); - return determineCoalescingTime(currentGap, maxCoalesceWindow); - } - - @Override - public String toString() - { - return "Moving average"; - } - } - - /** - * A fixed strategy as a backup in case MovingAverage or TimeHorizongMovingAverage fails in some scenario - */ - @VisibleForTesting - static class FixedCoalescingStrategy extends AbstractCoalescingStrategy - { - private final long coalesceWindow; - - public FixedCoalescingStrategy(int coalesceWindowMicros, Logger logger, String displayName) - { - super(logger, displayName); - coalesceWindow = TimeUnit.MICROSECONDS.toNanos(coalesceWindowMicros); - } - - public void newArrival(Coalescable message) - { - 
debugTimestamp(message.timestampNanos()); - } - - public long currentCoalescingTimeNanos() - { - return coalesceWindow; - } - - @Override - public String toString() - { - return "Fixed"; - } - } - - public static Optional newCoalescingStrategy(String strategy, int coalesceWindow, Logger logger, String displayName) - { - String strategyCleaned = strategy.trim().toUpperCase(Locale.ENGLISH); - - try - { - switch (Enum.valueOf(Strategy.class, strategyCleaned)) - { - case MOVINGAVERAGE: - return Optional.of(new MovingAverageCoalescingStrategy(coalesceWindow, logger, displayName)); - case FIXED: - return Optional.of(new FixedCoalescingStrategy(coalesceWindow, logger, displayName)); - case TIMEHORIZON: - long initialEpoch = System.nanoTime(); - return Optional.of(new TimeHorizonMovingAverageCoalescingStrategy(coalesceWindow, logger, displayName, initialEpoch)); - case DISABLED: - return Optional.empty(); - default: - throw new IllegalArgumentException("supported coalese strategy"); - } - } - catch (IllegalArgumentException iae) - { - try - { - Class clazz = Class.forName(strategy); - - if (!CoalescingStrategy.class.isAssignableFrom(clazz)) - throw new RuntimeException(strategy + " is not an instance of CoalescingStrategy"); - - Constructor constructor = clazz.getConstructor(int.class, Logger.class, String.class); - return Optional.of((CoalescingStrategy)constructor.newInstance(coalesceWindow, logger, displayName)); - } - catch (Exception e) - { - throw new RuntimeException(e); - } - } - } -} diff --git a/src/java/org/apache/cassandra/utils/ExecutorUtils.java b/src/java/org/apache/cassandra/utils/ExecutorUtils.java new file mode 100644 index 000000000000..21933a312360 --- /dev/null +++ b/src/java/org/apache/cassandra/utils/ExecutorUtils.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.utils; + +import java.util.Arrays; +import java.util.Collection; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import org.apache.cassandra.concurrent.InfiniteLoopExecutor; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +public class ExecutorUtils +{ + + public static Runnable runWithThreadName(Runnable runnable, String threadName) + { + return () -> { + String oldThreadName = Thread.currentThread().getName(); + try + { + Thread.currentThread().setName(threadName); + runnable.run(); + } + finally + { + Thread.currentThread().setName(oldThreadName); + } + }; + } + + public static void shutdownNow(Iterable executors) + { + shutdown(true, executors); + } + + public static void shutdown(Iterable executors) + { + shutdown(false, executors); + } + + public static void shutdown(boolean interrupt, Iterable executors) + { + for (Object executor : executors) + { + if (executor instanceof ExecutorService) + { + if (interrupt) ((ExecutorService) executor).shutdownNow(); + else ((ExecutorService) executor).shutdown(); + } + else if (executor instanceof InfiniteLoopExecutor) + ((InfiniteLoopExecutor) executor).shutdownNow(); + else if (executor instanceof 
Thread) + ((Thread) executor).interrupt(); + else if (executor != null) + throw new IllegalArgumentException(executor.toString()); + } + } + + public static void shutdown(ExecutorService ... executors) + { + shutdown(Arrays.asList(executors)); + } + + public static void shutdownNow(ExecutorService ... executors) + { + shutdownNow(Arrays.asList(executors)); + } + + public static void awaitTermination(long timeout, TimeUnit unit, ExecutorService ... executors) throws InterruptedException, TimeoutException + { + awaitTermination(timeout, unit, Arrays.asList(executors)); + } + + public static void awaitTermination(long timeout, TimeUnit unit, Collection executors) throws InterruptedException, TimeoutException + { + long deadline = System.nanoTime() + unit.toNanos(timeout); + awaitTerminationUntil(deadline, executors); + } + + public static void awaitTerminationUntil(long deadline, Collection executors) throws InterruptedException, TimeoutException + { + for (Object executor : executors) + { + long wait = deadline - System.nanoTime(); + if (executor instanceof ExecutorService) + { + if (wait <= 0 || !((ExecutorService)executor).awaitTermination(wait, NANOSECONDS)) + throw new TimeoutException(executor + " did not terminate on time"); + } + else if (executor instanceof InfiniteLoopExecutor) + { + if (wait <= 0 || !((InfiniteLoopExecutor)executor).awaitTermination(wait, NANOSECONDS)) + throw new TimeoutException(executor + " did not terminate on time"); + } + else if (executor instanceof Thread) + { + Thread t = (Thread) executor; + if (wait <= 0) + throw new TimeoutException(executor + " did not terminate on time"); + t.join((wait + 999999) / 1000000L, (int) (wait % 1000000L)); + if (t.isAlive()) + throw new TimeoutException(executor + " did not terminate on time"); + } + else if (executor != null) + { + throw new IllegalArgumentException(executor.toString()); + } + } + } + + public static void shutdownAndWait(long timeout, TimeUnit unit, Collection executors) throws 
TimeoutException, InterruptedException + { + shutdown(executors); + awaitTermination(timeout, unit, executors); + } + + public static void shutdownNowAndWait(long timeout, TimeUnit unit, Collection executors) throws TimeoutException, InterruptedException + { + shutdownNow(executors); + awaitTermination(timeout, unit, executors); + } + + public static void shutdownAndWait(long timeout, TimeUnit unit, Object ... executors) throws TimeoutException, InterruptedException + { + shutdownAndWait(timeout, unit, Arrays.asList(executors)); + } + + public static void shutdownNowAndWait(long timeout, TimeUnit unit, Object ... executors) throws TimeoutException, InterruptedException + { + shutdownNowAndWait(timeout, unit, Arrays.asList(executors)); + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/utils/ExpiringMap.java b/src/java/org/apache/cassandra/utils/ExpiringMap.java deleted file mode 100644 index ef013f57fdd6..000000000000 --- a/src/java/org/apache/cassandra/utils/ExpiringMap.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.cassandra.utils; - -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; - -import com.google.common.base.Function; -import com.google.common.util.concurrent.Uninterruptibles; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.concurrent.DebuggableScheduledThreadPoolExecutor; - -public class ExpiringMap -{ - private static final Logger logger = LoggerFactory.getLogger(ExpiringMap.class); - private volatile boolean shutdown; - - public static class CacheableObject - { - public final T value; - public final long timeout; - private final long createdAt; - - private CacheableObject(T value, long timeout) - { - assert value != null; - this.value = value; - this.timeout = timeout; - this.createdAt = Clock.instance.nanoTime(); - } - - private boolean isReadyToDieAt(long atNano) - { - return atNano - createdAt > TimeUnit.MILLISECONDS.toNanos(timeout); - } - } - - // if we use more ExpiringMaps we may want to add multiple threads to this executor - private static final ScheduledExecutorService service = new DebuggableScheduledThreadPoolExecutor("EXPIRING-MAP-REAPER"); - - private final ConcurrentMap> cache = new ConcurrentHashMap>(); - private final long defaultExpiration; - - public ExpiringMap(long defaultExpiration) - { - this(defaultExpiration, null); - } - - /** - * - * @param defaultExpiration the TTL for objects in the cache in milliseconds - */ - public ExpiringMap(long defaultExpiration, final Function>, ?> postExpireHook) - { - this.defaultExpiration = defaultExpiration; - - if (defaultExpiration <= 0) - { - throw new IllegalArgumentException("Argument specified must be a positive number"); - } - - Runnable runnable = new Runnable() - { - public void run() - { - long start = Clock.instance.nanoTime(); - int n = 0; - for 
(Map.Entry> entry : cache.entrySet()) - { - if (entry.getValue().isReadyToDieAt(start)) - { - if (cache.remove(entry.getKey()) != null) - { - n++; - if (postExpireHook != null) - postExpireHook.apply(Pair.create(entry.getKey(), entry.getValue())); - } - } - } - logger.trace("Expired {} entries", n); - } - }; - service.scheduleWithFixedDelay(runnable, defaultExpiration / 2, defaultExpiration / 2, TimeUnit.MILLISECONDS); - } - - public boolean shutdownBlocking() - { - service.shutdown(); - try - { - return service.awaitTermination(defaultExpiration * 2, TimeUnit.MILLISECONDS); - } - catch (InterruptedException e) - { - throw new AssertionError(e); - } - } - - public void reset() - { - shutdown = false; - cache.clear(); - } - - public V put(K key, V value) - { - return put(key, value, this.defaultExpiration); - } - - public V put(K key, V value, long timeout) - { - if (shutdown) - { - // StorageProxy isn't equipped to deal with "I'm nominally alive, but I can't send any messages out." - // So we'll just sit on this thread until the rest of the server shutdown completes. - // - // See comments in CustomTThreadPoolServer.serve, CASSANDRA-3335, and CASSANDRA-3727. - Uninterruptibles.sleepUninterruptibly(Long.MAX_VALUE, TimeUnit.NANOSECONDS); - } - CacheableObject previous = cache.put(key, new CacheableObject(value, timeout)); - return (previous == null) ? null : previous.value; - } - - public V get(K key) - { - CacheableObject co = cache.get(key); - return co == null ? null : co.value; - } - - public V remove(K key) - { - CacheableObject co = cache.remove(key); - return co == null ? null : co.value; - } - - /** - * @return System.nanoTime() when key was put into the map. - */ - public long getAge(K key) - { - CacheableObject co = cache.get(key); - return co == null ? 
0 : co.createdAt; - } - - public int size() - { - return cache.size(); - } - - public boolean containsKey(K key) - { - return cache.containsKey(key); - } - - public boolean isEmpty() - { - return cache.isEmpty(); - } - - public Set keySet() - { - return cache.keySet(); - } -} diff --git a/src/java/org/apache/cassandra/utils/FBUtilities.java b/src/java/org/apache/cassandra/utils/FBUtilities.java index 129c0f56885b..f0d913246b57 100644 --- a/src/java/org/apache/cassandra/utils/FBUtilities.java +++ b/src/java/org/apache/cassandra/utils/FBUtilities.java @@ -265,30 +265,6 @@ public static int compareUnsigned(byte[] bytes1, byte[] bytes2) return compareUnsigned(bytes1, bytes2, 0, 0, bytes1.length, bytes2.length); } - /** - * @return The bitwise XOR of the inputs. The output will be the same length as the - * longer input, but if either input is null, the output will be null. - */ - public static byte[] xor(byte[] left, byte[] right) - { - if (left == null || right == null) - return null; - if (left.length > right.length) - { - byte[] swap = left; - left = right; - right = swap; - } - - // left.length is now <= right.length - byte[] out = Arrays.copyOf(right, right.length); - for (int i = 0; i < left.length; i++) - { - out[i] = (byte)((left[i] & 0xFF) ^ (right[i] & 0xFF)); - } - return out; - } - public static void sortSampledKeys(List keys, Range range) { if (range.left.compareTo(range.right) >= 0) @@ -449,12 +425,6 @@ public static T waitOnFuture(Future future) } } - public static void waitOnFutures(List results, long ms) throws TimeoutException - { - for (AsyncOneResponse result : results) - result.get(ms, TimeUnit.MILLISECONDS); - } - public static Future waitOnFirstFuture(Iterable> futures) { return waitOnFirstFuture(futures, 100); diff --git a/src/java/org/apache/cassandra/utils/FastByteOperations.java b/src/java/org/apache/cassandra/utils/FastByteOperations.java index 6581736c5d0b..060dee59ad4a 100644 --- 
a/src/java/org/apache/cassandra/utils/FastByteOperations.java +++ b/src/java/org/apache/cassandra/utils/FastByteOperations.java @@ -55,6 +55,16 @@ public static int compareUnsigned(byte[] b1, int s1, int l1, ByteBuffer b2) return -BestHolder.BEST.compare(b2, b1, s1, l1); } + public static int compareUnsigned(ByteBuffer b1, int s1, int l1, byte[] b2, int s2, int l2) + { + return BestHolder.BEST.compare(b1, s1, l1, b2, s2, l2); + } + + public static int compareUnsigned(byte[] b1, int s1, int l1, ByteBuffer b2, int s2, int l2) + { + return -BestHolder.BEST.compare(b2, s2, l2, b1, s1, l1); + } + public static int compareUnsigned(ByteBuffer b1, ByteBuffer b2) { return BestHolder.BEST.compare(b1, b2); @@ -77,6 +87,8 @@ abstract public int compare(byte[] buffer1, int offset1, int length1, abstract public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2); + abstract public int compare(ByteBuffer buffer1, int offset1, int length1, byte[] buffer2, int offset2, int length2); + abstract public int compare(ByteBuffer buffer1, ByteBuffer buffer2); abstract public void copy(ByteBuffer src, int srcPosition, byte[] trg, int trgPosition, int length); @@ -186,26 +198,25 @@ public int compare(byte[] buffer1, int offset1, int length1, byte[] buffer2, int } public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2) + { + return compare(buffer1, buffer1.position(), buffer1.remaining(), buffer2, offset2, length2); + } + + public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2) { Object obj1; long offset1; if (buffer1.hasArray()) { obj1 = buffer1.array(); - offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset(); + offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset() + position1; } else { obj1 = null; - offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET); - } - int length1; - { - int position = buffer1.position(); - int limit = buffer1.limit(); - length1 = limit - 
position; - offset1 += position; + offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET) + position1; } + return compareTo(obj1, offset1, length1, buffer2, BYTE_ARRAY_BASE_OFFSET + offset2, length2); } @@ -397,11 +408,28 @@ public int compare(byte[] buffer1, int offset1, int length1, return length1 - length2; } + public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2) + { + if (buffer1.hasArray()) + return compare(buffer1.array(), buffer1.arrayOffset() + position1, length1, buffer2, offset2, length2); + + if (position1 != buffer1.position()) + { + buffer1 = buffer1.duplicate(); + buffer1.position(position1); + } + + return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2)); + } + public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2) { if (buffer1.hasArray()) + { return compare(buffer1.array(), buffer1.arrayOffset() + buffer1.position(), buffer1.remaining(), buffer2, offset2, length2); + } + return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2)); } diff --git a/src/java/org/apache/cassandra/utils/JVMStabilityInspector.java b/src/java/org/apache/cassandra/utils/JVMStabilityInspector.java index e058ae2bcce6..cdcbf4bdd545 100644 --- a/src/java/org/apache/cassandra/utils/JVMStabilityInspector.java +++ b/src/java/org/apache/cassandra/utils/JVMStabilityInspector.java @@ -53,7 +53,12 @@ private JVMStabilityInspector() {} * @param t * The Throwable to check for server-stop conditions */ - public static void inspectThrowable(Throwable t) + public static void inspectThrowable(Throwable t) throws OutOfMemoryError + { + inspectThrowable(t, true); + } + + public static void inspectThrowable(Throwable t, boolean propagateOutOfMemory) throws OutOfMemoryError { boolean isUnstable = false; if (t instanceof OutOfMemoryError) @@ -76,6 +81,9 @@ public static void inspectThrowable(Throwable t) StorageService.instance.removeShutdownHook(); // We let the JVM handle the 
error. The startup checks should have warned the user if it did not configure // the JVM behavior in case of OOM (CASSANDRA-13006). + if (!propagateOutOfMemory) + return; + throw (OutOfMemoryError) t; } diff --git a/src/java/org/apache/cassandra/utils/MerkleTree.java b/src/java/org/apache/cassandra/utils/MerkleTree.java index 1d51f03b11a8..1b9255511d08 100644 --- a/src/java/org/apache/cassandra/utils/MerkleTree.java +++ b/src/java/org/apache/cassandra/utils/MerkleTree.java @@ -19,25 +19,35 @@ import java.io.DataInput; import java.io.IOException; -import java.io.Serializable; +import java.nio.ByteBuffer; import java.util.*; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.PeekingIterator; +import com.google.common.primitives.Ints; +import com.google.common.primitives.Shorts; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.dht.IPartitioner; -import org.apache.cassandra.dht.IPartitionerDependentSerializer; +import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.dht.RandomPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.concurrent.Ref; +import org.apache.cassandra.utils.memory.MemoryUtil; + +import static java.lang.String.format; +import static org.apache.cassandra.db.TypeSizes.sizeof; +import static org.apache.cassandra.utils.ByteBufferUtil.compare; +import static org.apache.cassandra.utils.MerkleTree.Difference.*; /** * A MerkleTree implemented as a binary tree. 
@@ -59,84 +69,45 @@ * If two MerkleTrees have the same hashdepth, they represent a perfect tree * of the same depth, and can always be compared, regardless of size or splits. */ -public class MerkleTree implements Serializable +public class MerkleTree { - private static Logger logger = LoggerFactory.getLogger(MerkleTree.class); + private static final Logger logger = LoggerFactory.getLogger(MerkleTree.class); - public static final MerkleTreeSerializer serializer = new MerkleTreeSerializer(); - private static final long serialVersionUID = 2L; + private static final int HASH_SIZE = 32; // 2xMM3_128 = 32 bytes. + private static final byte[] EMPTY_HASH = new byte[HASH_SIZE]; - public static final byte RECOMMENDED_DEPTH = Byte.MAX_VALUE - 1; + /* + * Thread-local byte array, large enough to host 32B of digest or MM3/Random partitoners' tokens + */ + private static final ThreadLocal byteArray = ThreadLocal.withInitial(() -> new byte[HASH_SIZE]); + + private static byte[] getTempArray(int minimumSize) + { + return minimumSize <= HASH_SIZE ? byteArray.get() : new byte[minimumSize]; + } - public static final int CONSISTENT = 0; - public static final int FULLY_INCONSISTENT = 1; - public static final int PARTIALLY_INCONSISTENT = 2; - private static final byte[] EMPTY_HASH = new byte[0]; + public static final byte RECOMMENDED_DEPTH = Byte.MAX_VALUE - 1; - public final byte hashdepth; + private final int hashdepth; /** The top level range that this MerkleTree covers. */ - public final Range fullRange; + final Range fullRange; private final IPartitioner partitioner; private long maxsize; private long size; - private Hashable root; + private Node root; - public static class MerkleTreeSerializer implements IVersionedSerializer + /** + * @param partitioner The partitioner in use. + * @param range the range this tree covers + * @param hashdepth The maximum depth of the tree. 100/(2^depth) is the % + * of the key space covered by each subrange of a fully populated tree. 
+ * @param maxsize The maximum number of subranges in the tree. + */ + public MerkleTree(IPartitioner partitioner, Range range, int hashdepth, long maxsize) { - public void serialize(MerkleTree mt, DataOutputPlus out, int version) throws IOException - { - out.writeByte(mt.hashdepth); - out.writeLong(mt.maxsize); - out.writeLong(mt.size); - out.writeUTF(mt.partitioner.getClass().getCanonicalName()); - // full range - Token.serializer.serialize(mt.fullRange.left, out, version); - Token.serializer.serialize(mt.fullRange.right, out, version); - Hashable.serializer.serialize(mt.root, out, version); - } - - public MerkleTree deserialize(DataInputPlus in, int version) throws IOException - { - byte hashdepth = in.readByte(); - long maxsize = in.readLong(); - long size = in.readLong(); - IPartitioner partitioner; - try - { - partitioner = FBUtilities.newPartitioner(in.readUTF()); - } - catch (ConfigurationException e) - { - throw new IOException(e); - } - - // full range - Token left = Token.serializer.deserialize(in, partitioner, version); - Token right = Token.serializer.deserialize(in, partitioner, version); - Range fullRange = new Range<>(left, right); - - MerkleTree mt = new MerkleTree(partitioner, fullRange, hashdepth, maxsize); - mt.size = size; - mt.root = Hashable.serializer.deserialize(in, partitioner, version); - return mt; - } - - public long serializedSize(MerkleTree mt, int version) - { - long size = 1 // mt.hashdepth - + TypeSizes.sizeof(mt.maxsize) - + TypeSizes.sizeof(mt.size) - + TypeSizes.sizeof(mt.partitioner.getClass().getCanonicalName()); - - // full range - size += Token.serializer.serializedSize(mt.fullRange.left, version); - size += Token.serializer.serializedSize(mt.fullRange.right, version); - - size += Hashable.serializer.serializedSize(mt.root, version); - return size; - } + this(new OnHeapLeaf(), partitioner, range, hashdepth, maxsize, 1); } /** @@ -145,60 +116,56 @@ public long serializedSize(MerkleTree mt, int version) * @param hashdepth The 
maximum depth of the tree. 100/(2^depth) is the % * of the key space covered by each subrange of a fully populated tree. * @param maxsize The maximum number of subranges in the tree. + * @param size The size of the tree. Typically 1, unless deserilized from an existing tree */ - public MerkleTree(IPartitioner partitioner, Range range, byte hashdepth, long maxsize) + private MerkleTree(Node root, IPartitioner partitioner, Range range, int hashdepth, long maxsize, long size) { assert hashdepth < Byte.MAX_VALUE; + + this.root = root; this.fullRange = Preconditions.checkNotNull(range); this.partitioner = Preconditions.checkNotNull(partitioner); this.hashdepth = hashdepth; this.maxsize = maxsize; - - size = 1; - root = new Leaf(null); - } - - - static byte inc(byte in) - { - assert in < Byte.MAX_VALUE; - return (byte)(in + 1); + this.size = size; } /** * Initializes this tree by splitting it until hashdepth is reached, * or until an additional level of splits would violate maxsize. * - * NB: Replaces all nodes in the tree. 
+ * NB: Replaces all nodes in the tree, and always builds on the heap */ public void init() { // determine the depth to which we can safely split the tree - byte sizedepth = (byte)(Math.log10(maxsize) / Math.log10(2)); - byte depth = (byte)Math.min(sizedepth, hashdepth); + int sizedepth = (int) (Math.log10(maxsize) / Math.log10(2)); + int depth = Math.min(sizedepth, hashdepth); - root = initHelper(fullRange.left, fullRange.right, (byte)0, depth); - size = (long)Math.pow(2, depth); + root = initHelper(fullRange.left, fullRange.right, 0, depth); + size = (long) Math.pow(2, depth); } - private Hashable initHelper(Token left, Token right, byte depth, byte max) + private OnHeapNode initHelper(Token left, Token right, int depth, int max) { if (depth == max) // we've reached the leaves - return new Leaf(); + return new OnHeapLeaf(); Token midpoint = partitioner.midpoint(left, right); if (midpoint.equals(left) || midpoint.equals(right)) - return new Leaf(); + return new OnHeapLeaf(); - Hashable lchild = initHelper(left, midpoint, inc(depth), max); - Hashable rchild = initHelper(midpoint, right, inc(depth), max); - return new Inner(midpoint, lchild, rchild); + OnHeapNode leftChild = initHelper(left, midpoint, depth + 1, max); + OnHeapNode rightChild = initHelper(midpoint, right, depth + 1, max); + return new OnHeapInner(midpoint, leftChild, rightChild); } - Hashable root() + public void release() { - return root; + if (root instanceof OffHeapNode) + ((OffHeapNode) root).release(); + root = null; } public IPartitioner partitioner() @@ -233,20 +200,21 @@ public void maxsize(long maxsize) public static List difference(MerkleTree ltree, MerkleTree rtree) { if (!ltree.fullRange.equals(rtree.fullRange)) - throw new IllegalArgumentException("Difference only make sense on tree covering the same range (but " + ltree.fullRange + " != " + rtree.fullRange + ")"); + throw new IllegalArgumentException("Difference only make sense on tree covering the same range (but " + ltree.fullRange + 
" != " + rtree.fullRange + ')'); + + // ensure on-heap trees' inner node hashes have been computed + ltree.fillInnerHashes(); + rtree.fillInnerHashes(); List diff = new ArrayList<>(); - TreeDifference active = new TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte)0); + TreeRange active = new TreeRange(ltree.fullRange.left, ltree.fullRange.right, 0); - Hashable lnode = ltree.find(active); - Hashable rnode = rtree.find(active); - byte[] lhash = lnode.hash(); - byte[] rhash = rnode.hash(); - active.setSize(lnode.sizeOfRange(), rnode.sizeOfRange()); + Node lnode = ltree.root; + Node rnode = rtree.root; - if (lhash != null && rhash != null && !Arrays.equals(lhash, rhash)) + if (lnode.hashesDiffer(rnode)) { - if(lnode instanceof Leaf || rnode instanceof Leaf) + if (lnode instanceof Leaf || rnode instanceof Leaf) { logger.debug("Digest mismatch detected among leaf nodes {}, {}", lnode, rnode); diff.add(active); @@ -261,20 +229,20 @@ public static List difference(MerkleTree ltree, MerkleTree rtree) } } } - else if (lhash == null || rhash == null) - diff.add(active); + return diff; } + enum Difference { CONSISTENT, FULLY_INCONSISTENT, PARTIALLY_INCONSISTENT } + /** - * TODO: This function could be optimized into a depth first traversal of - * the two trees in parallel. + * TODO: This function could be optimized into a depth first traversal of the two trees in parallel. * * Takes two trees and a range for which they have hashes, but are inconsistent. * @return FULLY_INCONSISTENT if active is inconsistent, PARTIALLY_INCONSISTENT if only a subrange is inconsistent. 
*/ @VisibleForTesting - static int differenceHelper(MerkleTree ltree, MerkleTree rtree, List diff, TreeRange active) + static Difference differenceHelper(MerkleTree ltree, MerkleTree rtree, List diff, TreeRange active) { if (active.depth == Byte.MAX_VALUE) return CONSISTENT; @@ -289,51 +257,46 @@ static int differenceHelper(MerkleTree ltree, MerkleTree rtree, List return FULLY_INCONSISTENT; } - TreeDifference left = new TreeDifference(active.left, midpoint, inc(active.depth)); - TreeDifference right = new TreeDifference(midpoint, active.right, inc(active.depth)); + TreeRange left = new TreeRange(active.left, midpoint, active.depth + 1); + TreeRange right = new TreeRange(midpoint, active.right, active.depth + 1); logger.debug("({}) Hashing sub-ranges [{}, {}] for {} divided by midpoint {}", active.depth, left, right, active, midpoint); - byte[] lhash, rhash; - Hashable lnode, rnode; + Node lnode, rnode; // see if we should recurse left lnode = ltree.find(left); rnode = rtree.find(left); - lhash = lnode.hash(); - rhash = rnode.hash(); - left.setSize(lnode.sizeOfRange(), rnode.sizeOfRange()); - left.setRows(lnode.rowsInRange(), rnode.rowsInRange()); - int ldiff = CONSISTENT; - boolean lreso = lhash != null && rhash != null; - if (lreso && !Arrays.equals(lhash, rhash)) + Difference ldiff = CONSISTENT; + if (null != lnode && null != rnode && lnode.hashesDiffer(rnode)) { logger.debug("({}) Inconsistent digest on left sub-range {}: [{}, {}]", active.depth, left, lnode, rnode); - if (lnode instanceof Leaf) ldiff = FULLY_INCONSISTENT; - else ldiff = differenceHelper(ltree, rtree, diff, left); + + if (lnode instanceof Leaf) + ldiff = FULLY_INCONSISTENT; + else + ldiff = differenceHelper(ltree, rtree, diff, left); } - else if (!lreso) + else if (null == lnode || null == rnode) { - logger.debug("({}) Left sub-range fully inconsistent {}", active.depth, right); + logger.debug("({}) Left sub-range fully inconsistent {}", active.depth, left); ldiff = FULLY_INCONSISTENT; } // see 
if we should recurse right lnode = ltree.find(right); rnode = rtree.find(right); - lhash = lnode.hash(); - rhash = rnode.hash(); - right.setSize(lnode.sizeOfRange(), rnode.sizeOfRange()); - right.setRows(lnode.rowsInRange(), rnode.rowsInRange()); - int rdiff = CONSISTENT; - boolean rreso = lhash != null && rhash != null; - if (rreso && !Arrays.equals(lhash, rhash)) + Difference rdiff = CONSISTENT; + if (null != lnode && null != rnode && lnode.hashesDiffer(rnode)) { logger.debug("({}) Inconsistent digest on right sub-range {}: [{}, {}]", active.depth, right, lnode, rnode); - if (rnode instanceof Leaf) rdiff = FULLY_INCONSISTENT; - else rdiff = differenceHelper(ltree, rtree, diff, right); + + if (rnode instanceof Leaf) + rdiff = FULLY_INCONSISTENT; + else + rdiff = differenceHelper(ltree, rtree, diff, right); } - else if (!rreso) + else if (null == lnode || null == rnode) { logger.debug("({}) Right sub-range fully inconsistent {}", active.depth, right); rdiff = FULLY_INCONSISTENT; @@ -362,133 +325,70 @@ else if (rdiff == FULLY_INCONSISTENT) } /** - * For testing purposes. - * Gets the smallest range containing the token. - */ - public TreeRange get(Token t) - { - return getHelper(root, fullRange.left, fullRange.right, (byte)0, t); - } - - TreeRange getHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t) - { - while (true) - { - if (hashable instanceof Leaf) - { - // we've reached a hash: wrap it up and deliver it - return new TreeRange(this, pleft, pright, depth, hashable); - } - // else: node. - - Inner node = (Inner) hashable; - depth = inc(depth); - if (Range.contains(pleft, node.token, t)) - { // left child contains token - hashable = node.lchild; - pright = node.token; - } - else - { // else: right child contains token - hashable = node.rchild; - pleft = node.token; - } - } - } - - /** - * Invalidates the ranges containing the given token. - * Useful for testing. 
- */ - public void invalidate(Token t) - { - invalidateHelper(root, fullRange.left, t); - } - - private void invalidateHelper(Hashable hashable, Token pleft, Token t) - { - hashable.hash(null); - if (hashable instanceof Leaf) - return; - // else: node. - - Inner node = (Inner)hashable; - if (Range.contains(pleft, node.token, t)) - // left child contains token - invalidateHelper(node.lchild, pleft, t); - else - // right child contains token - invalidateHelper(node.rchild, node.token, t); - } - - /** - * Hash the given range in the tree. The range must have been generated - * with recursive applications of partitioner.midpoint(). - * - * NB: Currently does not support wrapping ranges that do not end with - * partitioner.getMinimumToken(). - * - * @return Null if any subrange of the range is invalid, or if the exact - * range cannot be calculated using this tree. + * Exceptions that stop recursion early when we are sure that no answer + * can be found. */ - public byte[] hash(Range range) + static abstract class StopRecursion extends Exception { - return find(range).hash(); + static class TooDeep extends StopRecursion {} + static class BadRange extends StopRecursion {} } /** - * Find the {@link Hashable} node that matches the given {@code range}. + * Find the {@link Node} node that matches the given {@code range}. * * @param range Range to find - * @return {@link Hashable} found. If nothing found, return {@link Leaf} with null hash. + * @return {@link Node} found. If nothing found, return {@code null} */ - private Hashable find(Range range) + @VisibleForTesting + private Node find(Range range) { try { - return findHelper(root, new Range(fullRange.left, fullRange.right), range); + return findHelper(root, fullRange, range); } catch (StopRecursion e) { - return new Leaf(); + return null; } } /** * @throws StopRecursion If no match could be found for the range. 
*/ - private Hashable findHelper(Hashable current, Range activeRange, Range find) throws StopRecursion + private Node findHelper(Node current, Range activeRange, Range find) throws StopRecursion { while (true) { if (current instanceof Leaf) { if (!find.contains(activeRange)) - // we are not fully contained in this range! - throw new StopRecursion.BadRange(); + throw new StopRecursion.BadRange(); // we are not fully contained in this range! + return current; } - // else: node. - Inner node = (Inner) current; - Range leftRange = new Range<>(activeRange.left, node.token); - Range rightRange = new Range<>(node.token, activeRange.right); + assert current instanceof Inner; + Inner inner = (Inner) current; - if (find.contains(activeRange)) - // this node is fully contained in the range - return node.calc(); + if (find.contains(activeRange)) // this node is fully contained in the range + return inner.fillInnerHashes(); + + Token midpoint = inner.token(); + Range leftRange = new Range<>(activeRange.left, midpoint); + Range rightRange = new Range<>(midpoint, activeRange.right); // else: one of our children contains the range - if (leftRange.contains(find)) - { // left child contains/matches the range - current = node.lchild; + if (leftRange.contains(find)) // left child contains/matches the range + { activeRange = leftRange; + current = inner.left(); } - else if (rightRange.contains(find)) - { // right child contains/matches the range - current = node.rchild; + else if (rightRange.contains(find)) // right child contains/matches the range + { activeRange = rightRange; + current = inner.right(); } else { @@ -506,12 +406,12 @@ else if (rightRange.contains(find)) */ public boolean split(Token t) { - if (!(size < maxsize)) + if (size >= maxsize) return false; try { - root = splitHelper(root, fullRange.left, fullRange.right, (byte)0, t); + root = splitHelper(root, fullRange.left, fullRange.right, 0, t); } catch (StopRecursion.TooDeep e) { @@ -520,12 +420,12 @@ public boolean 
split(Token t) return true; } - private Hashable splitHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t) throws StopRecursion.TooDeep + private OnHeapNode splitHelper(Node node, Token pleft, Token pright, int depth, Token t) throws StopRecursion.TooDeep { if (depth >= hashdepth) throw new StopRecursion.TooDeep(); - if (hashable instanceof Leaf) + if (node instanceof Leaf) { Token midpoint = partitioner.midpoint(pleft, pright); @@ -536,47 +436,47 @@ private Hashable splitHelper(Hashable hashable, Token pleft, Token pright, byte // split size++; - return new Inner(midpoint, new Leaf(), new Leaf()); + return new OnHeapInner(midpoint, new OnHeapLeaf(), new OnHeapLeaf()); } // else: node. // recurse on the matching child - Inner node = (Inner)hashable; + assert node instanceof OnHeapInner; + OnHeapInner inner = (OnHeapInner) node; - if (Range.contains(pleft, node.token, t)) - // left child contains token - node.lchild(splitHelper(node.lchild, pleft, node.token, inc(depth), t)); - else - // else: right child contains token - node.rchild(splitHelper(node.rchild, node.token, pright, inc(depth), t)); - return node; + if (Range.contains(pleft, inner.token(), t)) // left child contains token + inner.left(splitHelper(inner.left(), pleft, inner.token(), depth + 1, t)); + else // else: right child contains token + inner.right(splitHelper(inner.right(), inner.token(), pright, depth + 1, t)); + + return inner; } /** * Returns a lazy iterator of invalid TreeRanges that need to be filled * in order to make the given Range valid. 
*/ - public TreeRangeIterator invalids() + TreeRangeIterator rangeIterator() { return new TreeRangeIterator(this); } - public EstimatedHistogram histogramOfRowSizePerLeaf() + EstimatedHistogram histogramOfRowSizePerLeaf() { HistogramBuilder histbuild = new HistogramBuilder(); for (TreeRange range : new TreeRangeIterator(this)) { - histbuild.add(range.hashable.sizeOfRange); + histbuild.add(range.node.sizeOfRange()); } return histbuild.buildWithStdevRangesAroundMean(); } - public EstimatedHistogram histogramOfRowCountPerLeaf() + EstimatedHistogram histogramOfRowCountPerLeaf() { HistogramBuilder histbuild = new HistogramBuilder(); for (TreeRange range : new TreeRangeIterator(this)) { - histbuild.add(range.hashable.rowsInRange); + histbuild.add(range.node.partitionsInRange()); } return histbuild.buildWithStdevRangesAroundMean(); } @@ -586,7 +486,7 @@ public long rowCount() long count = 0; for (TreeRange range : new TreeRangeIterator(this)) { - count += range.hashable.rowsInRange; + count += range.node.partitionsInRange(); } return count; } @@ -597,61 +497,23 @@ public String toString() StringBuilder buff = new StringBuilder(); buff.append("#"); + buff.append('>'); return buff.toString(); } - public static class TreeDifference extends TreeRange + @Override + public boolean equals(Object other) { - private static final long serialVersionUID = 6363654174549968183L; - - private long sizeOnLeft; - private long sizeOnRight; - private long rowsOnLeft; - private long rowsOnRight; - - void setSize(long sizeOnLeft, long sizeOnRight) - { - this.sizeOnLeft = sizeOnLeft; - this.sizeOnRight = sizeOnRight; - } - - void setRows(long rowsOnLeft, long rowsOnRight) - { - this.rowsOnLeft = rowsOnLeft; - this.rowsOnRight = rowsOnRight; - } - - public long sizeOnLeft() - { - return sizeOnLeft; - } - - public long sizeOnRight() - { - return sizeOnRight; - } - - public long rowsOnLeft() - { - return rowsOnLeft; - } - - public long rowsOnRight() - { - return rowsOnRight; - } - - public 
TreeDifference(Token left, Token right, byte depth) - { - super(null, left, right, depth, null); - } - - public long totalRows() - { - return rowsOnLeft + rowsOnRight; - } - + if (!(other instanceof MerkleTree)) + return false; + MerkleTree that = (MerkleTree) other; + + return this.root.equals(that.root) + && this.fullRange.equals(that.fullRange) + && this.partitioner == that.partitioner + && this.hashdepth == that.hashdepth + && this.maxsize == that.maxsize + && this.size == that.size; } /** @@ -664,28 +526,27 @@ public long totalRows() */ public static class TreeRange extends Range { - public static final long serialVersionUID = 1L; private final MerkleTree tree; - public final byte depth; - private final Hashable hashable; + public final int depth; + private final Node node; - TreeRange(MerkleTree tree, Token left, Token right, byte depth, Hashable hashable) + TreeRange(MerkleTree tree, Token left, Token right, int depth, Node node) { super(left, right); this.tree = tree; this.depth = depth; - this.hashable = hashable; + this.node = node; } - public void hash(byte[] hash) + TreeRange(Token left, Token right, int depth) { - assert tree != null : "Not intended for modification!"; - hashable.hash(hash); + this(null, left, right, depth, null); } - public byte[] hash() + public void hash(byte[] hash) { - return hashable.hash(); + assert tree != null : "Not intended for modification!"; + node.hash(hash); } /** @@ -693,33 +554,26 @@ public byte[] hash() */ public void addHash(RowHash entry) { - assert tree != null : "Not intended for modification!"; - assert hashable instanceof Leaf; - - hashable.addHash(entry.hash, entry.size); + addHash(entry.hash, entry.size); } - public void ensureHashInitialised() + void addHash(byte[] hash, long partitionSize) { assert tree != null : "Not intended for modification!"; - assert hashable instanceof Leaf; - if (hashable.hash == null) - hashable.hash = EMPTY_HASH; + assert node instanceof OnHeapLeaf; + ((OnHeapLeaf) 
node).addHash(hash, partitionSize); } public void addAll(Iterator entries) { - while (entries.hasNext()) - addHash(entries.next()); + while (entries.hasNext()) addHash(entries.next()); } @Override public String toString() { - StringBuilder buff = new StringBuilder("#").toString(); + return "#"; } + } + + public void serialize(DataOutputPlus out, int version) throws IOException + { + out.writeByte(hashdepth); + out.writeLong(maxsize); + out.writeLong(size); + out.writeUTF(partitioner.getClass().getCanonicalName()); + Token.serializer.serialize(fullRange.left, out, version); + Token.serializer.serialize(fullRange.right, out, version); + root.serialize(out, version); + } + + public long serializedSize(int version) + { + long size = 1 // mt.hashdepth + + sizeof(maxsize) + + sizeof(this.size) + + sizeof(partitioner.getClass().getCanonicalName()); + size += Token.serializer.serializedSize(fullRange.left, version); + size += Token.serializer.serializedSize(fullRange.right, version); + size += root.serializedSize(version); + return size; + } + + public static MerkleTree deserialize(DataInputPlus in, int version) throws IOException + { + return deserialize(in, DatabaseDescriptor.useOffheapMerkleTrees(), version); + } + + public static MerkleTree deserialize(DataInputPlus in, boolean offHeapRequested, int version) throws IOException + { + int hashDepth = in.readByte(); + long maxSize = in.readLong(); + int innerNodeCount = Ints.checkedCast(in.readLong()); - public Hashable rchild() + IPartitioner partitioner; + try { - return rchild; + partitioner = FBUtilities.newPartitioner(in.readUTF()); } - - public void lchild(Hashable child) + catch (ConfigurationException e) { - lchild = child; + throw new IOException(e); } - public void rchild(Hashable child) + Token left = Token.serializer.deserialize(in, partitioner, version); + Token right = Token.serializer.deserialize(in, partitioner, version); + Range fullRange = new Range<>(left, right); + Node root = deserializeTree(in, 
partitioner, innerNodeCount, offHeapRequested, version); + return new MerkleTree(root, partitioner, fullRange, hashDepth, maxSize, innerNodeCount); + } + + private static boolean shouldUseOffHeapTrees(IPartitioner partitioner, boolean offHeapRequested) + { + boolean offHeapSupported = partitioner instanceof Murmur3Partitioner || partitioner instanceof RandomPartitioner; + + if (offHeapRequested && !offHeapSupported && !warnedOnce) { - rchild = child; + logger.warn("Configuration requests off-heap merkle trees, but partitioner does not support it. Ignoring."); + warnedOnce = true; } - Hashable calc() + return offHeapRequested && offHeapSupported; + } + private static boolean warnedOnce; + + private static ByteBuffer allocate(int innerNodeCount, IPartitioner partitioner) + { + int size = offHeapBufferSize(innerNodeCount, partitioner); + logger.debug("Allocating direct buffer of size {} for an off-heap merkle tree", size); + ByteBuffer buffer = ByteBuffer.allocateDirect(size); + if (Ref.DEBUG_ENABLED) + MemoryUtil.setAttachment(buffer, new Ref<>(null, null)); + return buffer; + } + + private static Node deserializeTree(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, boolean offHeapRequested, int version) throws IOException + { + return shouldUseOffHeapTrees(partitioner, offHeapRequested) + ? deserializeOffHeap(in, partitioner, innerNodeCount, version) + : OnHeapNode.deserialize(in, partitioner, version); + } + + /* + * Coordinating multiple trees from multiple replicas can get expensive. + * On the deserialization path, we know in advance what the tree looks like, + * So we can pre-size an offheap buffer and deserialize into that. + */ + MerkleTree tryMoveOffHeap() throws IOException + { + return root instanceof OnHeapNode && shouldUseOffHeapTrees(partitioner, DatabaseDescriptor.useOffheapMerkleTrees()) + ? 
moveOffHeap() + : this; + } + + @VisibleForTesting + MerkleTree moveOffHeap() throws IOException + { + assert root instanceof OnHeapNode; + root.fillInnerHashes(); // ensure on-heap trees' inner node hashes have been computed + ByteBuffer buffer = allocate(Ints.checkedCast(size), partitioner); + int pointer = ((OnHeapNode) root).serializeOffHeap(buffer, partitioner); + OffHeapNode newRoot = fromPointer(pointer, buffer, partitioner); + return new MerkleTree(newRoot, partitioner, fullRange, hashdepth, maxsize, size); + } + + private static OffHeapNode deserializeOffHeap(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, int version) throws IOException + { + ByteBuffer buffer = allocate(innerNodeCount, partitioner); + int pointer = OffHeapNode.deserialize(in, buffer, partitioner, version); + return fromPointer(pointer, buffer, partitioner); + } + + private static OffHeapNode fromPointer(int pointer, ByteBuffer buffer, IPartitioner partitioner) + { + return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer); + } + + private static int offHeapBufferSize(int innerNodeCount, IPartitioner partitioner) + { + return innerNodeCount * OffHeapInner.maxOffHeapSize(partitioner) + (innerNodeCount + 1) * OffHeapLeaf.maxOffHeapSize(); + } + + interface Node + { + byte[] hash(); + + boolean hasEmptyHash(); + + void hash(byte[] hash); + + boolean hashesDiffer(Node other); + + default Node fillInnerHashes() { - if (hash == null) - { - // hash and size haven't been calculated; calc children then compute - Hashable lnode = lchild.calc(); - Hashable rnode = rchild.calc(); - // cache the computed value - hash(lnode.hash, rnode.hash); - sizeOfRange = lnode.sizeOfRange + rnode.sizeOfRange; - rowsInRange = lnode.rowsInRange + rnode.rowsInRange; - } return this; } - /** - * Recursive toString. 
- */ - public void toString(StringBuilder buff, int maxdepth) + default long sizeOfRange() { - buff.append("#<").append(getClass().getSimpleName()); - buff.append(" ").append(token); - buff.append(" hash=").append(Hashable.toString(hash())); - buff.append(" children=["); - if (maxdepth < 1) - { - buff.append("#"); - } - else - { - if (lchild == null) - buff.append("null"); - else - lchild.toString(buff, maxdepth-1); - buff.append(" "); - if (rchild == null) - buff.append("null"); - else - rchild.toString(buff, maxdepth-1); - } - buff.append("]>"); + return 0; } - @Override - public String toString() + default long partitionsInRange() { - StringBuilder buff = new StringBuilder(); - toString(buff, 1); - return buff.toString(); + return 0; } - private static class InnerSerializer implements IPartitionerDependentSerializer - { - public void serialize(Inner inner, DataOutputPlus out, int version) throws IOException - { - Token.serializer.serialize(inner.token, out, version); - Hashable.serializer.serialize(inner.lchild, out, version); - Hashable.serializer.serialize(inner.rchild, out, version); - } + void serialize(DataOutputPlus out, int version) throws IOException; + int serializedSize(int version); - public Inner deserialize(DataInput in, IPartitioner p, int version) throws IOException - { - Token token = Token.serializer.deserialize(in, p, version); - Hashable lchild = Hashable.serializer.deserialize(in, p, version); - Hashable rchild = Hashable.serializer.deserialize(in, p, version); - return new Inner(token, lchild, rchild); - } + void toString(StringBuilder buff, int maxdepth); - public long serializedSize(Inner inner, int version) - { - return Token.serializer.serializedSize(inner.token, version) - + Hashable.serializer.serializedSize(inner.lchild, version) - + Hashable.serializer.serializedSize(inner.rchild, version); - } + static String toString(byte[] hash) + { + return hash == null + ? 
"null" + : '[' + Hex.bytesToHex(hash) + ']'; } + + boolean equals(Node node); } - /** - * A leaf node in the MerkleTree. Because the MerkleTree represents a much - * larger perfect binary tree of depth hashdepth, a Leaf object contains - * the value that would be contained in the perfect tree at its position. - * - * When rows are added to the MerkleTree using TreeRange.validate(), the - * tree extending below the Leaf is generated in memory, but only the root - * is stored in the Leaf. - */ - static class Leaf extends Hashable + static abstract class OnHeapNode implements Node { - public static final long serialVersionUID = 1L; - static final byte IDENT = 1; - private static final LeafSerializer serializer = new LeafSerializer(); + long sizeOfRange; + long partitionsInRange; - /** - * Constructs a null hash. - */ - public Leaf() + protected byte[] hash; + + OnHeapNode(byte[] hash) { - super(null); + if (hash == null) + throw new IllegalArgumentException(); + + this.hash = hash; } - public Leaf(byte[] hash) + public byte[] hash() { - super(hash); + return hash; } - public void toString(StringBuilder buff, int maxdepth) + public boolean hasEmptyHash() { - buff.append(toString()); + //noinspection ArrayEquality + return hash == EMPTY_HASH; } - @Override - public String toString() + public void hash(byte[] hash) { - return "#"; + if (hash == null) + throw new IllegalArgumentException(); + + this.hash = hash; } - private static class LeafSerializer implements IPartitionerDependentSerializer + public boolean hashesDiffer(Node other) { - public void serialize(Leaf leaf, DataOutputPlus out, int version) throws IOException - { - if (leaf.hash == null) - { - out.writeByte(-1); - } - else - { - out.writeByte(leaf.hash.length); - out.write(leaf.hash); - } - } + return other instanceof OnHeapNode + ? 
hashesDiffer( (OnHeapNode) other) + : hashesDiffer((OffHeapNode) other); + } + + private boolean hashesDiffer(OnHeapNode other) + { + return !Arrays.equals(hash(), other.hash()); + } + + private boolean hashesDiffer(OffHeapNode other) + { + return compare(hash(), other.buffer(), other.hashBytesOffset(), HASH_SIZE) != 0; + } + + @Override + public long sizeOfRange() + { + return sizeOfRange; + } + + @Override + public long partitionsInRange() + { + return partitionsInRange; + } + + static OnHeapNode deserialize(DataInputPlus in, IPartitioner p, int version) throws IOException + { + byte ident = in.readByte(); - public Leaf deserialize(DataInput in, IPartitioner p, int version) throws IOException + switch (ident) { - int hashLen = in.readByte(); - byte[] hash = hashLen < 0 ? null : new byte[hashLen]; - if (hash != null) - in.readFully(hash); - return new Leaf(hash); + case Inner.IDENT: + return OnHeapInner.deserializeWithoutIdent(in, p, version); + case Leaf.IDENT: + return OnHeapLeaf.deserializeWithoutIdent(in); + default: + throw new IOException("Unexpected node type: " + ident); } + } + + abstract int serializeOffHeap(ByteBuffer buffer, IPartitioner p) throws IOException; + } + + static abstract class OffHeapNode implements Node + { + protected final ByteBuffer buffer; + protected final int offset; + + OffHeapNode(ByteBuffer buffer, int offset) + { + this.buffer = buffer; + this.offset = offset; + } + + ByteBuffer buffer() + { + return buffer; + } + + public byte[] hash() + { + final int position = buffer.position(); + buffer.position(hashBytesOffset()); + byte[] array = new byte[HASH_SIZE]; + buffer.get(array); + buffer.position(position); + return array; + } + + public boolean hasEmptyHash() + { + return compare(buffer(), hashBytesOffset(), HASH_SIZE, EMPTY_HASH) == 0; + } + + public void hash(byte[] hash) + { + throw new UnsupportedOperationException(); + } + + public boolean hashesDiffer(Node other) + { + return other instanceof OnHeapNode + ? 
hashesDiffer((OnHeapNode) other) + : hashesDiffer((OffHeapNode) other); + } - public long serializedSize(Leaf leaf, int version) + private boolean hashesDiffer(OnHeapNode other) + { + return compare(buffer(), hashBytesOffset(), HASH_SIZE, other.hash()) != 0; + } + + private boolean hashesDiffer(OffHeapNode other) + { + int thisOffset = hashBytesOffset(); + int otherOffset = other.hashBytesOffset(); + + for (int i = 0; i < HASH_SIZE; i += 8) + if (buffer().getLong(thisOffset + i) != other.buffer().getLong(otherOffset + i)) + return true; + + return false; + } + + void release() + { + Object attachment = MemoryUtil.getAttachment(buffer); + if (attachment instanceof Ref) + ((Ref) attachment).release(); + FileUtils.clean(buffer); + } + + abstract int hashBytesOffset(); + + static int deserialize(DataInputPlus in, ByteBuffer buffer, IPartitioner p, int version) throws IOException + { + byte ident = in.readByte(); + + switch (ident) { - long size = 1; - if (leaf.hash != null) - size += leaf.hash().length; - return size; + case Inner.IDENT: + return OffHeapInner.deserializeWithoutIdent(in, buffer, p, version); + case Leaf.IDENT: + return OffHeapLeaf.deserializeWithoutIdent(in, buffer); + default: + throw new IOException("Unexpected node type: " + ident); } } } /** - * Hash value representing a row, to be used to pass hashes to the MerkleTree. - * The byte[] hash value should contain a digest of the key and value of the row - * created using a very strong hash function. + * A leaf node in the MerkleTree. Because the MerkleTree represents a much + * larger perfect binary tree of depth hashdepth, a Leaf object contains + * the value that would be contained in the perfect tree at its position. + * + * When rows are added to the MerkleTree using TreeRange.validate(), the + * tree extending below the Leaf is generated in memory, but only the root + * is stored in the Leaf. 
*/ - public static class RowHash + interface Leaf extends Node { - public final Token token; - public final byte[] hash; - public final long size; - public RowHash(Token token, byte[] hash, long size) + static final byte IDENT = 1; + + default void serialize(DataOutputPlus out, int version) throws IOException { - this.token = token; - this.hash = hash; - this.size = size; + byte[] hash = hash(); + assert hash.length == HASH_SIZE; + + out.writeByte(Leaf.IDENT); + + if (!hasEmptyHash()) + { + out.writeByte(HASH_SIZE); + out.write(hash); + } + else + { + out.writeByte(0); + } + } + + default int serializedSize(int version) + { + return 2 + (hasEmptyHash() ? 0 : HASH_SIZE); + } + + default void toString(StringBuilder buff, int maxdepth) + { + buff.append(toString()); + } + + default boolean equals(Node other) + { + return other instanceof Leaf && !hashesDiffer(other); + } + } + + static class OnHeapLeaf extends OnHeapNode implements Leaf + { + OnHeapLeaf() + { + super(EMPTY_HASH); + } + + OnHeapLeaf(byte[] hash) + { + super(hash); + } + + /** + * Mixes the given value into our hash. If our hash is null, + * our hash will become the given value. 
+ */ + void addHash(byte[] partitionHash, long partitionSize) + { + if (hasEmptyHash()) + hash(partitionHash); + else + xorOntoLeft(hash, partitionHash); + + sizeOfRange += partitionSize; + partitionsInRange += 1; + } + + static OnHeapLeaf deserializeWithoutIdent(DataInputPlus in) throws IOException + { + int size = in.readByte(); + switch (size) + { + case HASH_SIZE: + byte[] hash = new byte[HASH_SIZE]; + in.readFully(hash); + return new OnHeapLeaf(hash); + case 0: + return new OnHeapLeaf(); + default: + throw new IllegalStateException(format("Hash of size %d encountered, expecting %d or %d", size, HASH_SIZE, 0)); + } + } + + int serializeOffHeap(ByteBuffer buffer, IPartitioner p) + { + if (buffer.remaining() < OffHeapLeaf.maxOffHeapSize()) + throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap"); + + if (hash.length != HASH_SIZE) + throw new IllegalArgumentException("Hash of unexpected size when serializing a Leaf off-heap: " + hash.length); + + final int position = buffer.position(); + buffer.put(hash); + return ~position; + } + + @Override + public String toString() + { + return "#'; + } + } + + static class OffHeapLeaf extends OffHeapNode implements Leaf + { + static final int HASH_BYTES_OFFSET = 0; + + OffHeapLeaf(ByteBuffer buffer, int offset) + { + super(buffer, offset); + } + + public int hashBytesOffset() + { + return offset + HASH_BYTES_OFFSET; + } + + static int deserializeWithoutIdent(DataInput in, ByteBuffer buffer) throws IOException + { + if (buffer.remaining() < maxOffHeapSize()) + throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap"); + + final int position = buffer.position(); + + int hashLength = in.readByte(); + if (hashLength > 0) + { + if (hashLength != HASH_SIZE) + throw new IllegalStateException("Hash of unexpected size when deserializing an off-heap Leaf node: " + hashLength); + + byte[] hashBytes = getTempArray(HASH_SIZE); + in.readFully(hashBytes, 
0, HASH_SIZE); + buffer.put(hashBytes, 0, HASH_SIZE); + } + else + { + buffer.put(EMPTY_HASH, 0, HASH_SIZE); + } + + return ~position; + } + + static int maxOffHeapSize() + { + return HASH_SIZE; } @Override public String toString() { - return "#"; + return "#'; } } /** - * Abstract class containing hashing logic, and containing a single hash field. + * An inner node in the MerkleTree. Inners can contain cached hash values, which + * are the binary hash of their two children. */ - static abstract class Hashable implements Serializable + interface Inner extends Node { - private static final long serialVersionUID = 1L; - private static final IPartitionerDependentSerializer serializer = new HashableSerializer(); + static final byte IDENT = 2; - protected byte[] hash; - protected long sizeOfRange; - protected long rowsInRange; + public Token token(); + + public Node left(); + public Node right(); - protected Hashable(byte[] hash) + default void serialize(DataOutputPlus out, int version) throws IOException { - this.hash = hash; + out.writeByte(Inner.IDENT); + Token.serializer.serialize(token(), out, version); + left().serialize(out, version); + right().serialize(out, version); } - public byte[] hash() + default int serializedSize(int version) { - return hash; + return 1 + + (int) Token.serializer.serializedSize(token(), version) + + left().serializedSize(version) + + right().serializedSize(version); } - public long sizeOfRange() + default void toString(StringBuilder buff, int maxdepth) { - return sizeOfRange; + buff.append("#<").append(getClass().getSimpleName()) + .append(' ').append(token()) + .append(" hash=").append(Node.toString(hash())) + .append(" children=["); + + if (maxdepth < 1) + { + buff.append('#'); + } + else + { + Node left = left(); + if (left == null) + buff.append("null"); + else + left.toString(buff, maxdepth - 1); + + buff.append(' '); + + Node right = right(); + if (right == null) + buff.append("null"); + else + right.toString(buff, maxdepth - 1); + 
} + + buff.append("]>"); } - public long rowsInRange() + default boolean equals(Node other) { - return rowsInRange; + if (!(other instanceof Inner)) + return false; + Inner that = (Inner) other; + return !hashesDiffer(other) && this.left().equals(that.left()) && this.right().equals(that.right()); } - void hash(byte[] hash) + default void unsafeInvalidate() { - this.hash = hash; } + } - Hashable calc() + static class OnHeapInner extends OnHeapNode implements Inner + { + private final Token token; + + private OnHeapNode left; + private OnHeapNode right; + + private boolean computed; + + OnHeapInner(Token token, OnHeapNode left, OnHeapNode right) { - return this; + super(EMPTY_HASH); + + this.token = token; + this.left = left; + this.right = right; } - /** - * Sets the value of this hash to binaryHash of its children. - * @param lefthash Hash of left child. - * @param righthash Hash of right child. - */ - void hash(byte[] lefthash, byte[] righthash) + public Token token() { - hash = binaryHash(lefthash, righthash); + return token; } - /** - * Mixes the given value into our hash. If our hash is null, - * our hash will become the given value. - */ - void addHash(byte[] righthash, long sizeOfRow) + public OnHeapNode left() { - if (hash == null) - hash = righthash; - else - hash = binaryHash(hash, righthash); - this.sizeOfRange += sizeOfRow; - this.rowsInRange += 1; + return left; } - /** - * The primitive with which all hashing should be accomplished: hashes - * a left and right value together. 
- */ - static byte[] binaryHash(final byte[] left, final byte[] right) + public OnHeapNode right() { - return FBUtilities.xor(left, right); + return right; } - public abstract void toString(StringBuilder buff, int maxdepth); + void left(OnHeapNode child) + { + left = child; + } - public static String toString(byte[] hash) + void right(OnHeapNode child) { - if (hash == null) - return "null"; - return "[" + Hex.bytesToHex(hash) + "]"; + right = child; } - private static class HashableSerializer implements IPartitionerDependentSerializer + @Override + public Node fillInnerHashes() { - public void serialize(Hashable h, DataOutputPlus out, int version) throws IOException + if (!computed) // hash and size haven't been calculated; compute children then compute this { - if (h instanceof Inner) - { - out.writeByte(Inner.IDENT); - Inner.serializer.serialize((Inner)h, out, version); - } - else if (h instanceof Leaf) - { - out.writeByte(Leaf.IDENT); - Leaf.serializer.serialize((Leaf) h, out, version); - } - else - throw new IOException("Unexpected Hashable: " + h.getClass().getCanonicalName()); - } + left.fillInnerHashes(); + right.fillInnerHashes(); - public Hashable deserialize(DataInput in, IPartitioner p, int version) throws IOException - { - byte ident = in.readByte(); - if (Inner.IDENT == ident) - return Inner.serializer.deserialize(in, p, version); - else if (Leaf.IDENT == ident) - return Leaf.serializer.deserialize(in, p, version); - else - throw new IOException("Unexpected Hashable: " + ident); + if (!left.hasEmptyHash() && !right.hasEmptyHash()) + hash = xor(left.hash(), right.hash()); + else if (left.hasEmptyHash()) + hash = right.hash(); + else if (right.hasEmptyHash()) + hash = left.hash(); + + sizeOfRange = left.sizeOfRange() + right.sizeOfRange(); + partitionsInRange = left.partitionsInRange() + right.partitionsInRange(); + + computed = true; } - public long serializedSize(Hashable h, int version) + return this; + } + + static OnHeapInner 
deserializeWithoutIdent(DataInputPlus in, IPartitioner p, int version) throws IOException + { + Token token = Token.serializer.deserialize(in, p, version); + OnHeapNode left = OnHeapNode.deserialize(in, p, version); + OnHeapNode right = OnHeapNode.deserialize(in, p, version); + return new OnHeapInner(token, left, right); + } + + int serializeOffHeap(ByteBuffer buffer, IPartitioner partitioner) throws IOException + { + if (buffer.remaining() < OffHeapInner.maxOffHeapSize(partitioner)) + throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap"); + + final int offset = buffer.position(); + + int tokenSize = partitioner.getTokenFactory().byteSize(token); + buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize)); + buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET); + partitioner.getTokenFactory().serialize(token, buffer); + + int leftPointer = left.serializeOffHeap(buffer, partitioner); + int rightPointer = right.serializeOffHeap(buffer, partitioner); + + buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET, leftPointer); + buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer); + + int leftHashOffset = OffHeapInner.hashBytesOffset(leftPointer); + int rightHashOffset = OffHeapInner.hashBytesOffset(rightPointer); + + for (int i = 0; i < HASH_SIZE; i += 8) { - if (h instanceof Inner) - return 1 + Inner.serializer.serializedSize((Inner) h, version); - else if (h instanceof Leaf) - return 1 + Leaf.serializer.serializedSize((Leaf) h, version); - throw new AssertionError(h.getClass()); + buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i, + buffer.getLong(leftHashOffset + i) ^ buffer.getLong(rightHashOffset + i)); } + + return offset; + } + + @Override + public String toString() + { + StringBuilder buff = new StringBuilder(); + toString(buff, 1); + return buff.toString(); + } + + @Override + public void unsafeInvalidate() + { + computed = false; } } - /** 
- * Exceptions that stop recursion early when we are sure that no answer - * can be found. - */ - static abstract class StopRecursion extends Exception + static class OffHeapInner extends OffHeapNode implements Inner { - static class BadRange extends StopRecursion + /** + * All we want to keep here is just a pointer to the start of the Inner leaf in the + * direct buffer. From there, we'll be able to deserialize the following, in this order: + * + * 1. pointer to left child (int) + * 2. pointer to right child (int) + * 3. hash bytes (space allocated as HASH_MAX_SIZE) + * 4. token length (short) + * 5. token bytes (variable length) + */ + static final int LEFT_CHILD_POINTER_OFFSET = 0; + static final int RIGHT_CHILD_POINTER_OFFSET = 4; + static final int HASH_BYTES_OFFSET = 8; + static final int TOKEN_LENGTH_OFFSET = 8 + HASH_SIZE; + static final int TOKEN_BYTES_OFFSET = TOKEN_LENGTH_OFFSET + 2; + + private final IPartitioner partitioner; + + OffHeapInner(ByteBuffer buffer, int offset, IPartitioner partitioner) + { + super(buffer, offset); + this.partitioner = partitioner; + } + + public Token token() + { + int length = buffer.getShort(offset + TOKEN_LENGTH_OFFSET); + return partitioner.getTokenFactory().fromByteBuffer(buffer, offset + TOKEN_BYTES_OFFSET, length); + } + + public Node left() + { + return child(LEFT_CHILD_POINTER_OFFSET); + } + + public Node right() { - public BadRange(){ super(); } + return child(RIGHT_CHILD_POINTER_OFFSET); } - static class InvalidHash extends StopRecursion + private Node child(int childOffset) { - public InvalidHash(){ super(); } + int pointer = buffer.getInt(offset + childOffset); + return pointer >= 0 ? 
new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer); } - static class TooDeep extends StopRecursion + public int hashBytesOffset() { - public TooDeep(){ super(); } + return offset + HASH_BYTES_OFFSET; } + + static int deserializeWithoutIdent(DataInputPlus in, ByteBuffer buffer, IPartitioner partitioner, int version) throws IOException + { + if (buffer.remaining() < maxOffHeapSize(partitioner)) + throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap"); + + final int offset = buffer.position(); + + int tokenSize = Token.serializer.deserializeSize(in); + byte[] tokenBytes = getTempArray(tokenSize); + in.readFully(tokenBytes, 0, tokenSize); + + buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize)); + buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET); + buffer.put(tokenBytes, 0, tokenSize); + + int leftPointer = OffHeapNode.deserialize(in, buffer, partitioner, version); + int rightPointer = OffHeapNode.deserialize(in, buffer, partitioner, version); + + buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET, leftPointer); + buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer); + + int leftHashOffset = hashBytesOffset(leftPointer); + int rightHashOffset = hashBytesOffset(rightPointer); + + for (int i = 0; i < HASH_SIZE; i += 8) + { + buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i, + buffer.getLong(leftHashOffset + i) ^ buffer.getLong(rightHashOffset + i)); + } + + return offset; + } + + static int maxOffHeapSize(IPartitioner partitioner) + { + return 4 // left pointer + + 4 // right pointer + + HASH_SIZE + + 2 + partitioner.getMaxTokenSize(); + } + + static int hashBytesOffset(int pointer) + { + return pointer >= 0 ? 
pointer + OffHeapInner.HASH_BYTES_OFFSET : ~pointer + OffHeapLeaf.HASH_BYTES_OFFSET; + } + + @Override + public String toString() + { + StringBuilder buff = new StringBuilder(); + toString(buff, 1); + return buff.toString(); + } + } + + /** + * @return The bitwise XOR of the inputs. + */ + static byte[] xor(byte[] left, byte[] right) + { + assert left.length == right.length; + + byte[] out = Arrays.copyOf(right, right.length); + for (int i = 0; i < left.length; i++) + out[i] = (byte)((left[i] & 0xFF) ^ (right[i] & 0xFF)); + return out; + } + + /** + * Bitwise XOR of the inputs, in place on the left array. + */ + private static void xorOntoLeft(byte[] left, byte[] right) + { + assert left.length == right.length; + + for (int i = 0; i < left.length; i++) + left[i] = (byte) ((left[i] & 0xFF) ^ (right[i] & 0xFF)); } /** @@ -1183,10 +1512,10 @@ public static int estimatedMaxDepthForBytes(IPartitioner partitioner, long numBy { byte[] hashLeft = new byte[bytesPerHash]; byte[] hashRigth = new byte[bytesPerHash]; - Leaf left = new Leaf(hashLeft); - Leaf right = new Leaf(hashRigth); - Inner inner = new Inner(partitioner.getMinimumToken(), left, right); - inner.calc(); + OnHeapLeaf left = new OnHeapLeaf(hashLeft); + OnHeapLeaf right = new OnHeapLeaf(hashRigth); + Inner inner = new OnHeapInner(partitioner.getMinimumToken(), left, right); + inner.fillInnerHashes(); // Some partioners have variable token sizes, try to estimate as close as we can by using the same // heap estimate as the memtables use. @@ -1201,4 +1530,124 @@ public static int estimatedMaxDepthForBytes(IPartitioner partitioner, long numBy long adjustedBytes = Math.max(1, (numBytes + sizeOfInner) / (sizeOfLeaf + sizeOfInner)); return Math.max(1, (int) Math.floor(Math.log(adjustedBytes) / Math.log(2))); } + + /* + * Test-only methods. + */ + + /** + * Invalidates the ranges containing the given token. + * Useful for testing. 
+ */ + @VisibleForTesting + void unsafeInvalidate(Token t) + { + unsafeInvalidateHelper(root, fullRange.left, t); + } + + private void unsafeInvalidateHelper(Node node, Token pleft, Token t) + { + node.hash(EMPTY_HASH); + + if (node instanceof Leaf) + return; + + assert node instanceof Inner; + Inner inner = (Inner) node; + inner.unsafeInvalidate(); + + if (Range.contains(pleft, inner.token(), t)) + unsafeInvalidateHelper(inner.left(), pleft, t); // left child contains token + else + unsafeInvalidateHelper(inner.right(), inner.token(), t); // right child contains token + } + + /** + * Hash the given range in the tree. The range must have been generated + * with recursive applications of partitioner.midpoint(). + * + * NB: Currently does not support wrapping ranges that do not end with + * partitioner.getMinimumToken(). + * + * @return {@link #EMPTY_HASH} if any subrange of the range is invalid, or if the exact + * range cannot be calculated using this tree. + */ + @VisibleForTesting + byte[] hash(Range range) + { + return find(range).hash(); + } + + interface Consumer + { + void accept(Node node) throws E; + } + + @VisibleForTesting + boolean ifHashesRange(Range range, Consumer consumer) throws E + { + try + { + Node node = findHelper(root, new Range<>(fullRange.left, fullRange.right), range); + boolean hasHash = !node.hasEmptyHash(); + if (hasHash) + consumer.accept(node); + return hasHash; + } + catch (StopRecursion e) + { + return false; + } + } + + @VisibleForTesting + boolean hashesRange(Range range) + { + return ifHashesRange(range, n -> {}); + } + + /** + * For testing purposes. + * Gets the smallest range containing the token. 
+ */ + @VisibleForTesting + public TreeRange get(Token t) + { + return getHelper(root, fullRange.left, fullRange.right, t); + } + + private TreeRange getHelper(Node node, Token pleft, Token pright, Token t) + { + int depth = 0; + + while (true) + { + if (node instanceof Leaf) + { + // we've reached a hash: wrap it up and deliver it + return new TreeRange(this, pleft, pright, depth, node); + } + + assert node instanceof Inner; + Inner inner = (Inner) node; + + if (Range.contains(pleft, inner.token(), t)) // left child contains token + { + pright = inner.token(); + node = inner.left(); + } + else // right child contains token + { + pleft = inner.token(); + node = inner.right(); + } + + depth++; + } + } + + private void fillInnerHashes() + { + root.fillInnerHashes(); + } } diff --git a/src/java/org/apache/cassandra/utils/MerkleTrees.java b/src/java/org/apache/cassandra/utils/MerkleTrees.java index d2a80583af9a..0043fe07ad5a 100644 --- a/src/java/org/apache/cassandra/utils/MerkleTrees.java +++ b/src/java/org/apache/cassandra/utils/MerkleTrees.java @@ -44,9 +44,9 @@ public class MerkleTrees implements Iterable, MerkleTree> { public static final MerkleTreesSerializer serializer = new MerkleTreesSerializer(); - private Map, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator()); + private final Map, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator()); - private IPartitioner partitioner; + private final IPartitioner partitioner; /** * Creates empty MerkleTrees object. @@ -142,6 +142,15 @@ public void init() } } + /** + * Dereference all merkle trees and release direct memory for all off-heap trees. + */ + public void release() + { + merkleTrees.values().forEach(MerkleTree::release); + merkleTrees.clear(); + } + /** * Init a selected MerkleTree with an even tree distribution. 
* @@ -171,7 +180,7 @@ public boolean split(Token t) @VisibleForTesting public void invalidate(Token t) { - getMerkleTree(t).invalidate(t); + getMerkleTree(t).unsafeInvalidate(t); } /** @@ -247,11 +256,11 @@ private boolean validateNonOverlapping(MerkleTree tree) } /** - * Get an iterator for all the invalids generated by the MerkleTrees. + * Get an iterator for all the iterator generated by the MerkleTrees. * * @return */ - public TreeRangeIterator invalids() + public TreeRangeIterator rangeIterator() { return new TreeRangeIterator(); } @@ -285,30 +294,20 @@ public void logRowSizePerLeaf(Logger logger) @VisibleForTesting public byte[] hash(Range range) { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - boolean hashed = false; - - try + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - for (Range rt : merkleTrees.keySet()) - { - if (rt.intersects(range)) - { - byte[] bytes = merkleTrees.get(rt).hash(range); - if (bytes != null) - { - baos.write(bytes); - hashed = true; - } - } - } + boolean hashed = false; + + for (Map.Entry, MerkleTree> entry : merkleTrees.entrySet()) + if (entry.getKey().intersects(range)) + hashed |= entry.getValue().ifHashesRange(range, n -> baos.write(n.hash())); + + return hashed ? baos.toByteArray() : null; } catch (IOException e) { throw new RuntimeException("Unable to append merkle tree hash to result"); } - - return hashed ? baos.toByteArray() : null; } /** @@ -354,7 +353,7 @@ private MerkleTree.TreeRange nextIterator() { if (it.hasNext()) { - current = it.next().invalids(); + current = it.next().rangeIterator(); return current.next(); } @@ -368,6 +367,17 @@ public Iterator iterator() } } + /** + * @return a new {@link MerkleTrees} instance with all trees moved off heap. 
+ */ + public MerkleTrees tryMoveOffHeap() throws IOException + { + Map, MerkleTree> movedTrees = new TreeMap<>(new TokenRangeComparator()); + for (Map.Entry, MerkleTree> entry : merkleTrees.entrySet()) + movedTrees.put(entry.getKey(), entry.getValue().tryMoveOffHeap()); + return new MerkleTrees(partitioner, movedTrees.values()); + } + /** * Get the differences between the two sets of MerkleTrees. * @@ -379,9 +389,7 @@ public static List> difference(MerkleTrees ltree, MerkleTrees rtree { List> differences = new ArrayList<>(); for (MerkleTree tree : ltree.merkleTrees.values()) - { differences.addAll(MerkleTree.difference(tree, rtree.getMerkleTree(tree.fullRange))); - } return differences; } @@ -392,7 +400,7 @@ public void serialize(MerkleTrees trees, DataOutputPlus out, int version) throws out.writeInt(trees.merkleTrees.size()); for (MerkleTree tree : trees.merkleTrees.values()) { - MerkleTree.serializer.serialize(tree, out, version); + tree.serialize(out, version); } } @@ -405,7 +413,7 @@ public MerkleTrees deserialize(DataInputPlus in, int version) throws IOException { for (int i = 0; i < nTrees; i++) { - MerkleTree tree = MerkleTree.serializer.deserialize(in, version); + MerkleTree tree = MerkleTree.deserialize(in, version); trees.add(tree); if (partitioner == null) @@ -425,7 +433,7 @@ public long serializedSize(MerkleTrees trees, int version) long size = TypeSizes.sizeof(trees.merkleTrees.size()); for (MerkleTree tree : trees.merkleTrees.values()) { - size += MerkleTree.serializer.serializedSize(tree, version); + size += tree.serializedSize(version); } return size; } diff --git a/src/java/org/apache/cassandra/utils/MonotonicClock.java b/src/java/org/apache/cassandra/utils/MonotonicClock.java new file mode 100644 index 000000000000..5a1aa3c0361e --- /dev/null +++ b/src/java/org/apache/cassandra/utils/MonotonicClock.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.utils; + +import java.lang.reflect.Constructor; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.LongSupplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.Config; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +/** + * Wrapper around time related functions that are either implemented by using the default JVM calls + * or by using a custom implementation for testing purposes. + * + * See {@link #preciseTime} for how to use a custom implementation. + * + * Please note that {@link java.time.Clock} wasn't used, as it would not be possible to provide an + * implementation for {@link #now()} with the exact same properties of {@link System#nanoTime()}. + */ +public interface MonotonicClock +{ + /** + * Static singleton object that will be instantiated by default with a system clock + * implementation. Set cassandra.clock system property to a FQCN to use a + * different implementation instead. 
+ */ + public static final MonotonicClock preciseTime = Defaults.precise(); + public static final MonotonicClock approxTime = Defaults.approx(preciseTime); + + /** + * @see System#nanoTime() + * + * Provides a monotonic time that can be compared with any other such value produced by the same clock + * since the application started only; these times cannot be persisted or serialized to other nodes. + * + * Nanosecond precision. + */ + public long now(); + + /** + * @return nanoseconds of potential error + */ + public long error(); + + public MonotonicClockTranslation translate(); + + public boolean isAfter(long instant); + public boolean isAfter(long now, long instant); + + static class Defaults + { + private static final Logger logger = LoggerFactory.getLogger(MonotonicClock.class); + + private static MonotonicClock precise() + { + String sclock = System.getProperty("cassandra.clock"); + if (sclock == null) + sclock = System.getProperty("cassandra.monotonic_clock.precise"); + + if (sclock != null) + { + try + { + logger.debug("Using custom clock implementation: {}", sclock); + return (MonotonicClock) Class.forName(sclock).newInstance(); + } + catch (Exception e) + { + logger.error(e.getMessage(), e); + } + } + + return new SystemClock(); + } + + private static MonotonicClock approx(MonotonicClock precise) + { + String sclock = System.getProperty("cassandra.monotonic_clock.approx"); + if (sclock != null) + { + try + { + logger.debug("Using custom clock implementation: {}", sclock); + Class clazz = (Class) Class.forName(sclock); + + if (SystemClock.class.equals(clazz) && SystemClock.class.equals(precise.getClass())) + return precise; + + try + { + Constructor withPrecise = clazz.getConstructor(MonotonicClock.class); + return withPrecise.newInstance(precise); + } + catch (NoSuchMethodException nme) + { + } + + return clazz.newInstance(); + } + catch (Exception e) + { + logger.error(e.getMessage(), e); + } + } + + return new SampledClock(precise); + } + } + + static 
abstract class AbstractEpochSamplingClock implements MonotonicClock + { + private static final Logger logger = LoggerFactory.getLogger(AbstractEpochSamplingClock.class); + private static final String UPDATE_INTERVAL_PROPERTY = Config.PROPERTY_PREFIX + "NANOTIMETOMILLIS_TIMESTAMP_UPDATE_INTERVAL"; + private static final long UPDATE_INTERVAL_MS = Long.getLong(UPDATE_INTERVAL_PROPERTY, 10000); + + private static class AlmostSameTime implements MonotonicClockTranslation + { + final long millisSinceEpoch; + final long monotonicNanos; + final long error; // maximum error of millis measurement (in nanos) + + private AlmostSameTime(long millisSinceEpoch, long monotonicNanos, long errorNanos) + { + this.millisSinceEpoch = millisSinceEpoch; + this.monotonicNanos = monotonicNanos; + this.error = errorNanos; + } + + public long fromMillisSinceEpoch(long currentTimeMillis) + { + return monotonicNanos + MILLISECONDS.toNanos(currentTimeMillis - millisSinceEpoch); + } + + public long toMillisSinceEpoch(long nanoTime) + { + return millisSinceEpoch + TimeUnit.NANOSECONDS.toMillis(nanoTime - monotonicNanos); + } + + public long error() + { + return error; + } + } + + final LongSupplier millisSinceEpoch; + + private volatile AlmostSameTime almostSameTime = new AlmostSameTime(0L, 0L, Long.MAX_VALUE); + private Future almostSameTimeUpdater; + private static double failedAlmostSameTimeUpdateModifier = 1.0; + + AbstractEpochSamplingClock(LongSupplier millisSinceEpoch) + { + this.millisSinceEpoch = millisSinceEpoch; + resumeEpochSampling(); + } + + public MonotonicClockTranslation translate() + { + return almostSameTime; + } + + public synchronized void pauseEpochSampling() + { + if (almostSameTimeUpdater == null) + return; + + almostSameTimeUpdater.cancel(true); + try { almostSameTimeUpdater.get(); } catch (Throwable t) { } + almostSameTimeUpdater = null; + } + + public synchronized void resumeEpochSampling() + { + if (almostSameTimeUpdater != null) + throw new 
IllegalStateException("Already running"); + updateAlmostSameTime(); + logger.info("Scheduling approximate time conversion task with an interval of {} milliseconds", UPDATE_INTERVAL_MS); + almostSameTimeUpdater = ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(this::updateAlmostSameTime, UPDATE_INTERVAL_MS, UPDATE_INTERVAL_MS, MILLISECONDS); + } + + private void updateAlmostSameTime() + { + final int tries = 3; + long[] samples = new long[2 * tries + 1]; + samples[0] = System.nanoTime(); + for (int i = 1 ; i < samples.length ; i += 2) + { + samples[i] = millisSinceEpoch.getAsLong(); + samples[i + 1] = now(); + } + + int best = 1; + // take sample with minimum delta between calls + for (int i = 3 ; i < samples.length - 1 ; i += 2) + { + if ((samples[i+1] - samples[i-1]) < (samples[best+1]-samples[best-1])) + best = i; + } + + long millis = samples[best]; + long nanos = (samples[best+1] / 2) + (samples[best-1] / 2); + long error = (samples[best+1] / 2) - (samples[best-1] / 2); + + AlmostSameTime prev = almostSameTime; + AlmostSameTime next = new AlmostSameTime(millis, nanos, error); + + if (next.error > prev.error && next.error > prev.error * failedAlmostSameTimeUpdateModifier) + { + failedAlmostSameTimeUpdateModifier *= 1.1; + return; + } + + failedAlmostSameTimeUpdateModifier = 1.0; + almostSameTime = next; + } + } + + public static class SystemClock extends AbstractEpochSamplingClock + { + private SystemClock() + { + super(System::currentTimeMillis); + } + + @Override + public long now() + { + return System.nanoTime(); + } + + @Override + public long error() + { + return 1; + } + + @Override + public boolean isAfter(long instant) + { + return now() > instant; + } + + @Override + public boolean isAfter(long now, long instant) + { + return now > instant; + } + } + + public static class SampledClock implements MonotonicClock + { + private static final Logger logger = LoggerFactory.getLogger(SampledClock.class); + private static final int 
UPDATE_INTERVAL_MS = Math.max(1, Integer.parseInt(System.getProperty(Config.PROPERTY_PREFIX + "approximate_time_precision_ms", "2"))); + private static final long ERROR_NANOS = MILLISECONDS.toNanos(UPDATE_INTERVAL_MS); + + private final MonotonicClock precise; + + private volatile long almostNow; + private Future almostNowUpdater; + + public SampledClock(MonotonicClock precise) + { + this.precise = precise; + resumeNowSampling(); + } + + @Override + public long now() + { + return almostNow; + } + + @Override + public long error() + { + return ERROR_NANOS; + } + + @Override + public MonotonicClockTranslation translate() + { + return precise.translate(); + } + + @Override + public boolean isAfter(long instant) + { + return isAfter(almostNow, instant); + } + + @Override + public boolean isAfter(long now, long instant) + { + return now - ERROR_NANOS > instant; + } + + public synchronized void pauseNowSampling() + { + if (almostNowUpdater == null) + return; + + almostNowUpdater.cancel(true); + try { almostNowUpdater.get(); } catch (Throwable t) { } + almostNowUpdater = null; + } + + public synchronized void resumeNowSampling() + { + if (almostNowUpdater != null) + throw new IllegalStateException("Already running"); + + almostNow = precise.now(); + logger.info("Scheduling approximate time-check task with a precision of {} milliseconds", UPDATE_INTERVAL_MS); + almostNowUpdater = ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(() -> almostNow = precise.now(), UPDATE_INTERVAL_MS, UPDATE_INTERVAL_MS, MILLISECONDS); + } + + public synchronized void refreshNow() + { + pauseNowSampling(); + resumeNowSampling(); + } + } + +} diff --git a/src/java/org/apache/cassandra/utils/MonotonicClockTranslation.java b/src/java/org/apache/cassandra/utils/MonotonicClockTranslation.java new file mode 100644 index 000000000000..f7f83e471109 --- /dev/null +++ b/src/java/org/apache/cassandra/utils/MonotonicClockTranslation.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.utils; + +public interface MonotonicClockTranslation +{ + /** accepts millis since epoch, returns nanoTime in the related clock */ + public long fromMillisSinceEpoch(long currentTimeMillis); + /** accepts nanoTime in the related MonotinicClock, returns millis since epoch */ + public long toMillisSinceEpoch(long nanoTime); + /** Nanoseconds of probable error in the translation */ + public long error(); +} diff --git a/src/java/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillis.java b/src/java/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillis.java deleted file mode 100644 index 5aafbe55dc2f..000000000000 --- a/src/java/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillis.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.utils; - -import java.util.concurrent.TimeUnit; - -import org.apache.cassandra.concurrent.ScheduledExecutors; -import org.apache.cassandra.config.Config; - -/* - * Convert from nanotime to non-monotonic current time millis. Beware of weaker ordering guarantees. - */ -public class NanoTimeToCurrentTimeMillis -{ - /* - * How often to pull a new timestamp from the system. - */ - private static final String TIMESTAMP_UPDATE_INTERVAL_PROPERTY = Config.PROPERTY_PREFIX + "NANOTIMETOMILLIS_TIMESTAMP_UPDATE_INTERVAL"; - private static final long TIMESTAMP_UPDATE_INTERVAL = Long.getLong(TIMESTAMP_UPDATE_INTERVAL_PROPERTY, 10000); - - private static volatile long TIMESTAMP_BASE[] = new long[] { System.currentTimeMillis(), System.nanoTime() }; - - /* - * System.currentTimeMillis() is 25 nanoseconds. This is 2 nanoseconds (maybe) according to JMH. - * Faster than calling both currentTimeMillis() and nanoTime(). - * - * There is also the issue of how scalable nanoTime() and currentTimeMillis() are which is a moving target. - * - * These timestamps don't order with System.currentTimeMillis() because currentTimeMillis() can tick over - * before this one does. I have seen it behind by as much as 2ms on Linux and 25ms on Windows. 
- */ - public static long convert(long nanoTime) - { - final long timestampBase[] = TIMESTAMP_BASE; - return timestampBase[0] + TimeUnit.NANOSECONDS.toMillis(nanoTime - timestampBase[1]); - } - - public static void updateNow() - { - ScheduledExecutors.scheduledFastTasks.submit(NanoTimeToCurrentTimeMillis::updateTimestampBase); - } - - static - { - ScheduledExecutors.scheduledFastTasks.scheduleWithFixedDelay(NanoTimeToCurrentTimeMillis::updateTimestampBase, - TIMESTAMP_UPDATE_INTERVAL, - TIMESTAMP_UPDATE_INTERVAL, - TimeUnit.MILLISECONDS); - } - - private static void updateTimestampBase() - { - TIMESTAMP_BASE = new long[] { - Math.max(TIMESTAMP_BASE[0], System.currentTimeMillis()), - Math.max(TIMESTAMP_BASE[1], System.nanoTime()) }; - } -} diff --git a/src/java/org/apache/cassandra/utils/Throwables.java b/src/java/org/apache/cassandra/utils/Throwables.java index ec4e5c849000..5d6d96fac60c 100644 --- a/src/java/org/apache/cassandra/utils/Throwables.java +++ b/src/java/org/apache/cassandra/utils/Throwables.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Optional; +import java.util.function.Predicate; import java.util.stream.Stream; import org.apache.cassandra.io.FSReadError; @@ -37,6 +38,11 @@ public interface DiscreteAction void perform() throws E; } + public static boolean isCausedBy(Throwable t, Predicate cause) + { + return cause.test(t) || (t.getCause() != null && cause.test(t.getCause())); + } + public static T merge(T existingFail, T newFail) { if (existingFail == null) @@ -74,6 +80,12 @@ public static boolean failIfCanCast(Throwable fail, Class< return true; } + @SafeVarargs + public static void maybeFail(DiscreteAction ... actions) + { + maybeFail(Throwables.perform(null, Stream.of(actions))); + } + @SafeVarargs public static void perform(DiscreteAction ... 
actions) throws E { @@ -88,7 +100,7 @@ public static void perform(Stream void perform(Stream> actions) throws E { - Throwable fail = perform((Throwable) null, actions); + Throwable fail = perform(null, actions); if (failIfCanCast(fail, null)) throw (E) fail; } diff --git a/src/java/org/apache/cassandra/utils/concurrent/Ref.java b/src/java/org/apache/cassandra/utils/concurrent/Ref.java index 3c1b7cc1cdf7..a3733474457e 100644 --- a/src/java/org/apache/cassandra/utils/concurrent/Ref.java +++ b/src/java/org/apache/cassandra/utils/concurrent/Ref.java @@ -45,12 +45,15 @@ import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.Memory; import org.apache.cassandra.io.util.SafeMemory; +import org.apache.cassandra.utils.ExecutorUtils; import org.apache.cassandra.utils.NoSpamLogger; import org.apache.cassandra.utils.Pair; import org.cliffc.high_scale_lib.NonBlockingHashMap; import static java.util.Collections.emptyList; +import static org.apache.cassandra.utils.ExecutorUtils.awaitTermination; +import static org.apache.cassandra.utils.ExecutorUtils.shutdownNow; import static org.apache.cassandra.utils.Throwables.maybeFail; import static org.apache.cassandra.utils.Throwables.merge; @@ -705,14 +708,8 @@ private void removeExpected(Set candidates) } @VisibleForTesting - public static void shutdownReferenceReaper() throws InterruptedException + public static void shutdownReferenceReaper(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - EXEC.shutdown(); - EXEC.awaitTermination(60, TimeUnit.SECONDS); - if (STRONG_LEAK_DETECTOR != null) - { - STRONG_LEAK_DETECTOR.shutdownNow(); - STRONG_LEAK_DETECTOR.awaitTermination(60, TimeUnit.SECONDS); - } + ExecutorUtils.shutdownNowAndWait(timeout, unit, EXEC, STRONG_LEAK_DETECTOR); } } diff --git a/src/java/org/apache/cassandra/utils/concurrent/SimpleCondition.java b/src/java/org/apache/cassandra/utils/concurrent/SimpleCondition.java index 57614e005caa..0ff90185a5c2 100644 --- 
a/src/java/org/apache/cassandra/utils/concurrent/SimpleCondition.java +++ b/src/java/org/apache/cassandra/utils/concurrent/SimpleCondition.java @@ -48,10 +48,15 @@ public void await() throws InterruptedException public boolean await(long time, TimeUnit unit) throws InterruptedException { - if (isSignaled()) - return true; long start = System.nanoTime(); long until = start + unit.toNanos(time); + return awaitUntil(until); + } + + public boolean awaitUntil(long deadlineNanos) throws InterruptedException + { + if (isSignaled()) + return true; if (waiting == null) waitingUpdater.compareAndSet(this, null, new WaitQueue()); WaitQueue.Signal s = waiting.register(); @@ -60,7 +65,7 @@ public boolean await(long time, TimeUnit unit) throws InterruptedException s.cancel(); return true; } - return s.awaitUntil(until) || isSignaled(); + return s.awaitUntil(deadlineNanos) || isSignaled(); } public void signal() diff --git a/src/java/org/apache/cassandra/utils/concurrent/WaitQueue.java b/src/java/org/apache/cassandra/utils/concurrent/WaitQueue.java index 5b453b06475a..3647623095de 100644 --- a/src/java/org/apache/cassandra/utils/concurrent/WaitQueue.java +++ b/src/java/org/apache/cassandra/utils/concurrent/WaitQueue.java @@ -263,6 +263,15 @@ public static interface Signal * @throws InterruptedException */ public boolean awaitUntil(long nanos) throws InterruptedException; + + /** + * Wait until signalled, or the provided time is reached, or the thread is interrupted. 
If signalled, + * isSignalled() will be true on exit, and the method will return true; if timedout, the method will return + * false and isCancelled() will be true + * @param nanos System.nanoTime() to wait until + * @return true if signalled, false if timed out + */ + public boolean awaitUntilUninterruptibly(long nanos); } /** @@ -306,6 +315,17 @@ public boolean awaitUntil(long until) throws InterruptedException return checkAndClear(); } + public boolean awaitUntilUninterruptibly(long until) + { + long now; + while (until > (now = System.nanoTime()) && !isSignalled()) + { + long delta = until - now; + LockSupport.parkNanos(delta); + } + return checkAndClear(); + } + private void checkInterrupted() throws InterruptedException { if (Thread.interrupted()) diff --git a/src/java/org/apache/cassandra/utils/memory/BufferPool.java b/src/java/org/apache/cassandra/utils/memory/BufferPool.java index a67f5208e6d9..c2e81089c8e3 100644 --- a/src/java/org/apache/cassandra/utils/memory/BufferPool.java +++ b/src/java/org/apache/cassandra/utils/memory/BufferPool.java @@ -21,12 +21,21 @@ import java.lang.ref.PhantomReference; import java.lang.ref.ReferenceQueue; import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayDeque; +import java.util.Collections; import java.util.Queue; +import java.util.Set; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiPredicate; +import java.util.function.Consumer; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; + +import net.nicoulaj.compilecommand.annotations.Inline; import org.apache.cassandra.concurrent.InfiniteLoopExecutor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,29 +45,41 @@ import org.apache.cassandra.io.compress.BufferType; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.metrics.BufferPoolMetrics; -import 
org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.NoSpamLogger; import org.apache.cassandra.utils.concurrent.Ref; +import static com.google.common.collect.ImmutableList.of; +import static org.apache.cassandra.utils.ExecutorUtils.*; +import static org.apache.cassandra.utils.FBUtilities.prettyPrintMemory; + /** * A pool of ByteBuffers that can be recycled. + * + * TODO: document the semantics of this class carefully + * Notably: we do not automatically release from the local pool any chunk that has been incompletely allocated from */ public class BufferPool { - /** The size of a page aligned buffer, 64KiB */ - public static final int CHUNK_SIZE = 64 << 10; + /** The size of a page aligned buffer, 128KiB */ + public static final int NORMAL_CHUNK_SIZE = 128 << 10; + public static final int NORMAL_ALLOCATION_UNIT = NORMAL_CHUNK_SIZE / 64; + public static final int TINY_CHUNK_SIZE = NORMAL_ALLOCATION_UNIT; + public static final int TINY_ALLOCATION_UNIT = TINY_CHUNK_SIZE / 64; + public static final int TINY_ALLOCATION_LIMIT = TINY_CHUNK_SIZE / 2; + private final static BufferPoolMetrics metrics = new BufferPoolMetrics(); + + // TODO: this should not be using FileCacheSizeInMB @VisibleForTesting public static long MEMORY_USAGE_THRESHOLD = DatabaseDescriptor.getFileCacheSizeInMB() * 1024L * 1024L; @VisibleForTesting public static boolean ALLOCATE_ON_HEAP_WHEN_EXAHUSTED = DatabaseDescriptor.getBufferPoolUseHeapIfExhausted(); - @VisibleForTesting - public static boolean DISABLED = Boolean.parseBoolean(System.getProperty("cassandra.test.disable_buffer_pool", "false")); + private static Debug debug; @VisibleForTesting - public static boolean DEBUG = false; + public static boolean DISABLED = Boolean.parseBoolean(System.getProperty("cassandra.test.disable_buffer_pool", "false")); private static final Logger logger = LoggerFactory.getLogger(BufferPool.class); private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(logger, 15L, 
TimeUnit.MINUTES); @@ -75,6 +96,11 @@ protected LocalPool initialValue() { return new LocalPool(); } + + protected void onRemoval(LocalPool value) + { + value.release(); + } }; public static ByteBuffer get(int size) @@ -82,65 +108,43 @@ public static ByteBuffer get(int size) if (DISABLED) return allocate(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); else - return takeFromPool(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + return localPool.get().get(size, false, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); } public static ByteBuffer get(int size, BufferType bufferType) { - boolean direct = bufferType == BufferType.OFF_HEAP; - if (DISABLED || !direct) - return allocate(size, !direct); + boolean onHeap = bufferType == BufferType.ON_HEAP; + if (DISABLED || onHeap) + return allocate(size, onHeap); else - return takeFromPool(size, !direct); + return localPool.get().get(size, false, onHeap); } - /** Unlike the get methods, this will return null if the pool is exhausted */ - public static ByteBuffer tryGet(int size) + public static ByteBuffer getAtLeast(int size, BufferType bufferType) { - if (DISABLED) - return allocate(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + boolean onHeap = bufferType == BufferType.ON_HEAP; + if (DISABLED || onHeap) + return allocate(size, onHeap); else - return maybeTakeFromPool(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + return localPool.get().get(size, true, onHeap); } - private static ByteBuffer allocate(int size, boolean onHeap) + /** Unlike the get methods, this will return null if the pool is exhausted */ + public static ByteBuffer tryGet(int size) { - return onHeap - ? 
ByteBuffer.allocate(size) - : ByteBuffer.allocateDirect(size); + return localPool.get().tryGet(size, true); } - private static ByteBuffer takeFromPool(int size, boolean allocateOnHeapWhenExhausted) + public static ByteBuffer tryGetAtLeast(int size) { - ByteBuffer ret = maybeTakeFromPool(size, allocateOnHeapWhenExhausted); - if (ret != null) - return ret; - - if (logger.isTraceEnabled()) - logger.trace("Requested buffer size {} has been allocated directly due to lack of capacity", FBUtilities.prettyPrintMemory(size)); - - return localPool.get().allocate(size, allocateOnHeapWhenExhausted); + return localPool.get().tryGet(size, true); } - private static ByteBuffer maybeTakeFromPool(int size, boolean allocateOnHeapWhenExhausted) + private static ByteBuffer allocate(int size, boolean onHeap) { - if (size < 0) - throw new IllegalArgumentException("Size must be positive (" + size + ")"); - - if (size == 0) - return EMPTY_BUFFER; - - if (size > CHUNK_SIZE) - { - if (logger.isTraceEnabled()) - logger.trace("Requested buffer size {} is bigger than {}, allocating directly", - FBUtilities.prettyPrintMemory(size), - FBUtilities.prettyPrintMemory(CHUNK_SIZE)); - - return localPool.get().allocate(size, allocateOnHeapWhenExhausted); - } - - return localPool.get().get(size); + return onHeap + ? ByteBuffer.allocate(size) + : ByteBuffer.allocateDirect(size); } public static void put(ByteBuffer buffer) @@ -149,61 +153,43 @@ public static void put(ByteBuffer buffer) localPool.get().put(buffer); } - /** This is not thread safe and should only be used for unit testing. 
*/ - @VisibleForTesting - static void reset() + public static void putUnusedPortion(ByteBuffer buffer) { - localPool.get().reset(); - globalPool.reset(); + + if (!(DISABLED || buffer.hasArray())) + { + LocalPool pool = localPool.get(); + if (buffer.limit() > 0) + pool.putUnusedPortion(buffer); + else + pool.put(buffer); + } } - @VisibleForTesting - static Chunk currentChunk() + public static void setRecycleWhenFreeForCurrentThread(boolean recycleWhenFree) { - return localPool.get().chunks[0]; + localPool.get().recycleWhenFree(recycleWhenFree); } - @VisibleForTesting - static int numChunks() + public static long sizeInBytes() { - int ret = 0; - for (Chunk chunk : localPool.get().chunks) - { - if (chunk != null) - ret++; - } - return ret; + return globalPool.sizeInBytes(); } - @VisibleForTesting - static void assertAllRecycled() + interface Debug { - globalPool.debug.check(); + void registerNormal(Chunk chunk); + void recycleNormal(Chunk oldVersion, Chunk newVersion); } - public static long sizeInBytes() + public static void debug(Debug setDebug) { - return globalPool.sizeInBytes(); + debug = setDebug; } - static final class Debug + interface Recycler { - long recycleRound = 1; - final Queue allChunks = new ConcurrentLinkedQueue<>(); - void register(Chunk chunk) - { - allChunks.add(chunk); - } - void recycle(Chunk chunk) - { - chunk.lastRecycled = recycleRound; - } - void check() - { - for (Chunk chunk : allChunks) - assert chunk.lastRecycled == recycleRound; - recycleRound++; - } + void recycle(Chunk chunk); } /** @@ -213,26 +199,25 @@ void check() * * This class is shared by multiple thread local pools and must be thread-safe. 
*/ - static final class GlobalPool + static final class GlobalPool implements Supplier, Recycler { - /** The size of a bigger chunk, 1-mbit, must be a multiple of CHUNK_SIZE */ - static final int MACRO_CHUNK_SIZE = 1 << 20; + /** The size of a bigger chunk, 1 MiB, must be a multiple of NORMAL_CHUNK_SIZE */ + static final int MACRO_CHUNK_SIZE = 64 * NORMAL_CHUNK_SIZE; static { - assert Integer.bitCount(CHUNK_SIZE) == 1; // must be a power of 2 + assert Integer.bitCount(NORMAL_CHUNK_SIZE) == 1; // must be a power of 2 assert Integer.bitCount(MACRO_CHUNK_SIZE) == 1; // must be a power of 2 - assert MACRO_CHUNK_SIZE % CHUNK_SIZE == 0; // must be a multiple + assert MACRO_CHUNK_SIZE % NORMAL_CHUNK_SIZE == 0; // must be a multiple if (DISABLED) logger.info("Global buffer pool is disabled, allocating {}", ALLOCATE_ON_HEAP_WHEN_EXAHUSTED ? "on heap" : "off heap"); else logger.info("Global buffer pool is enabled, when pool is exhausted (max is {}) it will allocate {}", - FBUtilities.prettyPrintMemory(MEMORY_USAGE_THRESHOLD), + prettyPrintMemory(MEMORY_USAGE_THRESHOLD), ALLOCATE_ON_HEAP_WHEN_EXAHUSTED ? 
"on heap" : "off heap"); } - private final Debug debug = new Debug(); private final Queue macroChunks = new ConcurrentLinkedQueue<>(); // TODO (future): it would be preferable to use a CLStack to improve cache occupancy; it would also be preferable to use "CoreLocal" storage private final Queue chunks = new ConcurrentLinkedQueue<>(); @@ -265,7 +250,8 @@ private Chunk allocateMoreChunks() if (cur + MACRO_CHUNK_SIZE > MEMORY_USAGE_THRESHOLD) { noSpamLogger.info("Maximum memory usage reached ({}), cannot allocate chunk of {}", - MEMORY_USAGE_THRESHOLD, MACRO_CHUNK_SIZE); + prettyPrintMemory(MEMORY_USAGE_THRESHOLD), + prettyPrintMemory(MACRO_CHUNK_SIZE)); return null; } if (memoryUsage.compareAndSet(cur, cur + MACRO_CHUNK_SIZE)) @@ -276,36 +262,41 @@ private Chunk allocateMoreChunks() Chunk chunk; try { - chunk = new Chunk(allocateDirectAligned(MACRO_CHUNK_SIZE)); + chunk = new Chunk(null, allocateDirectAligned(MACRO_CHUNK_SIZE)); } catch (OutOfMemoryError oom) { noSpamLogger.error("Buffer pool failed to allocate chunk of {}, current size {} ({}). " + "Attempting to continue; buffers will be allocated in on-heap memory which can degrade performance. 
" + "Make sure direct memory size (-XX:MaxDirectMemorySize) is large enough to accommodate off-heap memtables and caches.", - MACRO_CHUNK_SIZE, sizeInBytes(), oom.toString()); + prettyPrintMemory(MACRO_CHUNK_SIZE), + prettyPrintMemory(sizeInBytes()), + oom.toString()); return null; } chunk.acquire(null); macroChunks.add(chunk); - final Chunk callerChunk = new Chunk(chunk.get(CHUNK_SIZE)); - if (DEBUG) - debug.register(callerChunk); - for (int i = CHUNK_SIZE ; i < MACRO_CHUNK_SIZE; i += CHUNK_SIZE) + final Chunk callerChunk = new Chunk(this, chunk.get(NORMAL_CHUNK_SIZE)); + if (debug != null) + debug.registerNormal(callerChunk); + for (int i = NORMAL_CHUNK_SIZE; i < MACRO_CHUNK_SIZE; i += NORMAL_CHUNK_SIZE) { - Chunk add = new Chunk(chunk.get(CHUNK_SIZE)); + Chunk add = new Chunk(this, chunk.get(NORMAL_CHUNK_SIZE)); chunks.add(add); - if (DEBUG) - debug.register(add); + if (debug != null) + debug.registerNormal(add); } return callerChunk; } public void recycle(Chunk chunk) { - chunks.add(chunk); + Chunk recycleAs = new Chunk(chunk); + if (debug != null) + debug.recycleNormal(chunk, recycleAs); + chunks.add(recycleAs); } public long sizeInBytes() @@ -315,25 +306,21 @@ public long sizeInBytes() /** This is not thread safe and should only be used for unit testing. */ @VisibleForTesting - void reset() + void unsafeFree() { while (!chunks.isEmpty()) - chunks.poll().reset(); + chunks.poll().unsafeFree(); while (!macroChunks.isEmpty()) - macroChunks.poll().reset(); + macroChunks.poll().unsafeFree(); memoryUsage.set(0); } } - /** - * A thread local class that grabs chunks from the global pool for this thread allocations. - * Only one thread can do the allocations but multiple threads can release the allocations. 
- */ - static final class LocalPool + private static class MicroQueueOfChunks { - private final static BufferPoolMetrics metrics = new BufferPoolMetrics(); + // a microqueue of Chunks: // * if any are null, they are at the end; // * new Chunks are added to the last null index @@ -341,92 +328,284 @@ static final class LocalPool // * this results in a queue that will typically be visited in ascending order of available space, so that // small allocations preferentially slice from the Chunks with the smallest space available to furnish them // WARNING: if we ever change the size of this, we must update removeFromLocalQueue, and addChunk - private final Chunk[] chunks = new Chunk[3]; - private byte chunkCount = 0; + private Chunk chunk0, chunk1, chunk2; + private int count; - public LocalPool() + // add a new chunk, if necessary evicting the chunk with the least available memory (returning the evicted chunk) + private Chunk add(Chunk chunk) { - localPoolReferences.add(new LocalPoolRef(this, localPoolRefQueue)); + switch (count) + { + case 0: + chunk0 = chunk; + count = 1; + break; + case 1: + chunk1 = chunk; + count = 2; + break; + case 2: + chunk2 = chunk; + count = 3; + break; + case 3: + { + Chunk release; + int chunk0Free = chunk0.freeSlotCount(); + int chunk1Free = chunk1.freeSlotCount(); + int chunk2Free = chunk2.freeSlotCount(); + if (chunk0Free < chunk1Free) + { + if (chunk0Free < chunk2Free) + { + release = chunk0; + chunk0 = chunk; + } + else + { + release = chunk2; + chunk2 = chunk; + } + } + else + { + if (chunk1Free < chunk2Free) + { + release = chunk1; + chunk1 = chunk; + } + else + { + release = chunk2; + chunk2 = chunk; + } + } + return release; + } + default: + throw new IllegalStateException(); + } + return null; } - private Chunk addChunkFromGlobalPool() + private void remove(Chunk chunk) { - Chunk chunk = globalPool.get(); - if (chunk == null) - return null; + // since we only have three elements in the queue, it is clearer, easier and faster to 
just hard code the options + if (chunk0 == chunk) + { // remove first by shifting back second two + chunk0 = chunk1; + chunk1 = chunk2; + } + else if (chunk1 == chunk) + { // remove second by shifting back last + chunk1 = chunk2; + } + else if (chunk2 != chunk) + { + return; + } + // whatever we do, the last element must be null + chunk2 = null; + --count; + } - addChunk(chunk); - return chunk; + ByteBuffer get(int size, boolean sizeIsLowerBound, ByteBuffer reuse) + { + ByteBuffer buffer; + if (null != chunk0) + { + if (null != (buffer = chunk0.get(size, sizeIsLowerBound, reuse))) + return buffer; + if (null != chunk1) + { + if (null != (buffer = chunk1.get(size, sizeIsLowerBound, reuse))) + return buffer; + if (null != chunk2 && null != (buffer = chunk2.get(size, sizeIsLowerBound, reuse))) + return buffer; + } + } + return null; } - private void addChunk(Chunk chunk) + private void forEach(Consumer consumer) { - chunk.acquire(this); + forEach(consumer, count, chunk0, chunk1, chunk2); + } - if (chunkCount < 3) + private void clearForEach(Consumer consumer) + { + Chunk chunk0 = this.chunk0, chunk1 = this.chunk1, chunk2 = this.chunk2; + this.chunk0 = this.chunk1 = this.chunk2 = null; + forEach(consumer, count, chunk0, chunk1, chunk2); + count = 0; + } + + private static void forEach(Consumer consumer, int count, Chunk chunk0, Chunk chunk1, Chunk chunk2) + { + switch (count) { - chunks[chunkCount++] = chunk; - return; + case 3: + consumer.accept(chunk2); + case 2: + consumer.accept(chunk1); + case 1: + consumer.accept(chunk0); } + } - int smallestChunkIdx = 0; - if (chunks[1].free() < chunks[0].free()) - smallestChunkIdx = 1; - if (chunks[2].free() < chunks[smallestChunkIdx].free()) - smallestChunkIdx = 2; + private void removeIf(BiPredicate predicate, T value) + { + switch (count) + { + case 3: + if (predicate.test(chunk2, value)) + { + --count; + Chunk chunk = chunk2; + chunk2 = null; + chunk.release(); + } + case 2: + if (predicate.test(chunk1, value)) + { + 
--count; + Chunk chunk = chunk1; + chunk1 = null; + chunk.release(); + } + case 1: + if (predicate.test(chunk0, value)) + { + --count; + Chunk chunk = chunk0; + chunk0 = null; + chunk.release(); + } + break; + case 0: + return; + } + switch (count) + { + case 2: + // Find the only null item, and shift non-null so that null is at chunk2 + if (chunk0 == null) + { + chunk0 = chunk1; + chunk1 = chunk2; + chunk2 = null; + } + else if (chunk1 == null) + { + chunk1 = chunk2; + chunk2 = null; + } + break; + case 1: + // Find the only non-null item, and shift it to chunk0 + if (chunk1 != null) + { + chunk0 = chunk1; + chunk1 = null; + } + else if (chunk2 != null) + { + chunk0 = chunk2; + chunk2 = null; + } + break; + } + } - chunks[smallestChunkIdx].release(); - if (smallestChunkIdx != 2) - chunks[smallestChunkIdx] = chunks[2]; - chunks[2] = chunk; + private void release() + { + clearForEach(Chunk::release); } - public ByteBuffer get(int size) + private void unsafeRecycle() { - for (Chunk chunk : chunks) - { // first see if our own chunks can serve this buffer - if (chunk == null) - break; + clearForEach(Chunk::unsafeRecycle); + } + } - ByteBuffer buffer = chunk.get(size); - if (buffer != null) - return buffer; - } + /** + * A thread local class that grabs chunks from the global pool for this thread allocations. + * Only one thread can do the allocations but multiple threads can release the allocations. 
+ */ + public static final class LocalPool implements Recycler + { + private final Queue reuseObjects; + private final Supplier parent; + private final LocalPoolRef leakRef; - // else ask the global pool - Chunk chunk = addChunkFromGlobalPool(); - if (chunk != null) - return chunk.get(size); + private final MicroQueueOfChunks chunks = new MicroQueueOfChunks(); + /** + * If we are on outer LocalPool, whose chunks are == NORMAL_CHUNK_SIZE, we may service allocation requests + * for buffers much smaller than + */ + private LocalPool tinyPool; + private final int tinyLimit; + private boolean recycleWhenFree = true; - return null; + public LocalPool() + { + this.parent = globalPool; + this.tinyLimit = TINY_ALLOCATION_LIMIT; + this.reuseObjects = new ArrayDeque<>(); + localPoolReferences.add(leakRef = new LocalPoolRef(this, localPoolRefQueue)); } - private ByteBuffer allocate(int size, boolean onHeap) + /** + * Invoked by an existing LocalPool, to create a child pool + */ + private LocalPool(LocalPool parent) { - metrics.misses.mark(); - return BufferPool.allocate(size, onHeap); + this.parent = () -> { + ByteBuffer buffer = parent.tryGetInternal(TINY_CHUNK_SIZE, false); + if (buffer == null) + return null; + return new Chunk(parent, buffer); + }; + this.tinyLimit = 0; // we only currently permit one layer of nesting (which brings us down to 32 byte allocations, so is plenty) + this.reuseObjects = parent.reuseObjects; // we share the same ByteBuffer object reuse pool, as we both have the same exclusive access to it + localPoolReferences.add(leakRef = new LocalPoolRef(this, localPoolRefQueue)); + } + + private LocalPool tinyPool() + { + if (tinyPool == null) + tinyPool = new LocalPool(this).recycleWhenFree(recycleWhenFree); + return tinyPool; } public void put(ByteBuffer buffer) { Chunk chunk = Chunk.getParentChunk(buffer); if (chunk == null) - { FileUtils.clean(buffer); + else + put(buffer, chunk); + } + + public void put(ByteBuffer buffer, Chunk chunk) + { + LocalPool 
owner = chunk.owner; + if (owner != null && owner == tinyPool) + { + tinyPool.put(buffer, chunk); return; } - LocalPool owner = chunk.owner; // ask the free method to take exclusive ownership of the act of recycling // if we are either: already not owned by anyone, or owned by ourselves - long free = chunk.free(buffer, owner == null | owner == this); + long free = chunk.free(buffer, owner == null || (owner == this && recycleWhenFree)); if (free == 0L) { // 0L => we own recycling responsibility, so must recycle; - chunk.recycle(); - // if we are also the owner, we must remove the Chunk from our local queue + // if we are the owner, we must remove the Chunk from our local queue if (owner == this) - removeFromLocalQueue(chunk); + remove(chunk); + chunk.recycle(); } else if (((free == -1L) && owner != this) && chunk.owner == null) { @@ -434,46 +613,186 @@ else if (((free == -1L) && owner != this) && chunk.owner == null) // we must also check after completely freeing if the owner has since been unset, and try to recycle chunk.tryRecycle(); } + + if (owner == this) + { + MemoryUtil.setAttachment(buffer, null); + MemoryUtil.setDirectByteBuffer(buffer, 0, 0); + reuseObjects.add(buffer); + } } - private void removeFromLocalQueue(Chunk chunk) + public void putUnusedPortion(ByteBuffer buffer) { - // since we only have three elements in the queue, it is clearer, easier and faster to just hard code the options - if (chunks[0] == chunk) - { // remove first by shifting back second two - chunks[0] = chunks[1]; - chunks[1] = chunks[2]; + Chunk chunk = Chunk.getParentChunk(buffer); + if (chunk == null) + return; + + chunk.freeUnusedPortion(buffer); + } + + public ByteBuffer get(int size) + { + return get(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + } + + public ByteBuffer get(int size, boolean allocateOnHeapWhenExhausted) + { + return get(size, false, allocateOnHeapWhenExhausted); + } + + public ByteBuffer getAtLeast(int size) + { + return getAtLeast(size, 
ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + } + + public ByteBuffer getAtLeast(int size, boolean allocateOnHeapWhenExhausted) + { + return get(size, true, allocateOnHeapWhenExhausted); + } + + private ByteBuffer get(int size, boolean sizeIsLowerBound, boolean allocateOnHeapWhenExhausted) + { + ByteBuffer ret = tryGet(size, sizeIsLowerBound); + if (ret != null) + return ret; + + if (size > NORMAL_CHUNK_SIZE) + { + if (logger.isTraceEnabled()) + logger.trace("Requested buffer size {} is bigger than {}; allocating directly", + prettyPrintMemory(size), + prettyPrintMemory(NORMAL_CHUNK_SIZE)); } - else if (chunks[1] == chunk) - { // remove second by shifting back last - chunks[1] = chunks[2]; + else + { + if (logger.isTraceEnabled()) + logger.trace("Requested buffer size {} has been allocated directly due to lack of capacity", prettyPrintMemory(size)); } - else assert chunks[2] == chunk; - // whatever we do, the last element myst be null - chunks[2] = null; - chunkCount--; + + metrics.misses.mark(); + return allocate(size, allocateOnHeapWhenExhausted); } - @VisibleForTesting - void reset() + public ByteBuffer tryGet(int size) + { + return tryGet(size, ALLOCATE_ON_HEAP_WHEN_EXAHUSTED); + } + + public ByteBuffer tryGetAtLeast(int size) { - chunkCount = 0; - for (int i = 0; i < chunks.length; i++) + return tryGet(size, true); + } + + private ByteBuffer tryGet(int size, boolean sizeIsLowerBound) + { + LocalPool pool = this; + if (size <= tinyLimit) { - if (chunks[i] != null) + if (size <= 0) { - chunks[i].owner = null; - chunks[i].freeSlots = 0L; - chunks[i].recycle(); - chunks[i] = null; + if (size == 0) + return EMPTY_BUFFER; + throw new IllegalArgumentException("Size must be non-negative (" + size + ')'); } + + pool = tinyPool(); } + else if (size > NORMAL_CHUNK_SIZE) + { + return null; + } + + return pool.tryGetInternal(size, sizeIsLowerBound); + } + + @Inline + private ByteBuffer tryGetInternal(int size, boolean sizeIsLowerBound) + { + ByteBuffer reuse = 
this.reuseObjects.poll(); + ByteBuffer buffer = chunks.get(size, sizeIsLowerBound, reuse); + if (buffer != null) + return buffer; + + // else ask the global pool + Chunk chunk = addChunkFromParent(); + if (chunk != null) + { + ByteBuffer result = chunk.get(size, sizeIsLowerBound, reuse); + if (result != null) + return result; + } + + if (reuse != null) + this.reuseObjects.add(reuse); + return null; + } + + // recycle + public void recycle(Chunk chunk) + { + ByteBuffer buffer = chunk.slab; + Chunk parentChunk = Chunk.getParentChunk(buffer); + put(buffer, parentChunk); + } + + private void remove(Chunk chunk) + { + chunks.remove(chunk); + if (tinyPool != null) + tinyPool.chunks.removeIf((child, parent) -> Chunk.getParentChunk(child.slab) == parent, chunk); + } + + private Chunk addChunkFromParent() + { + Chunk chunk = parent.get(); + if (chunk == null) + return null; + + addChunk(chunk); + return chunk; + } + + private void addChunk(Chunk chunk) + { + chunk.acquire(this); + Chunk evict = chunks.add(chunk); + if (evict != null) + { + if (tinyPool != null) + tinyPool.chunks.removeIf((child, parent) -> Chunk.getParentChunk(child.slab) == parent, evict); + evict.release(); + } + } + + public void release() + { + chunks.release(); + reuseObjects.clear(); + localPoolReferences.remove(leakRef); + leakRef.clear(); + if (tinyPool != null) + tinyPool.release(); + } + + @VisibleForTesting + void unsafeRecycle() + { + chunks.unsafeRecycle(); + } + + public LocalPool recycleWhenFree(boolean recycleWhenFree) + { + this.recycleWhenFree = recycleWhenFree; + if (tinyPool != null) + tinyPool.recycleWhenFree = recycleWhenFree; + return this; } } - private static final class LocalPoolRef extends PhantomReference + private static final class LocalPoolRef extends PhantomReference { - private final Chunk[] chunks; + private final MicroQueueOfChunks chunks; public LocalPoolRef(LocalPool localPool, ReferenceQueue q) { super(localPool, q); @@ -482,18 +801,11 @@ public LocalPoolRef(LocalPool 
localPool, ReferenceQueue q) public void release() { - for (int i = 0 ; i < chunks.length ; i++) - { - if (chunks[i] != null) - { - chunks[i].release(); - chunks[i] = null; - } - } + chunks.release(); } } - private static final ConcurrentLinkedQueue localPoolReferences = new ConcurrentLinkedQueue<>(); + private static final Set localPoolReferences = Collections.newSetFromMap(new ConcurrentHashMap<>()); private static final ReferenceQueue localPoolRefQueue = new ReferenceQueue<>(); private static final InfiniteLoopExecutor EXEC = new InfiniteLoopExecutor("LocalPool-Cleaner", BufferPool::cleanupOneReference).start(); @@ -561,8 +873,10 @@ final static class Chunk // if this is set, it means the chunk may not be recycled because we may still allocate from it; // if it has been unset the local pool has finished with it, and it may be recycled private volatile LocalPool owner; - private long lastRecycled; - private final Chunk original; + private final Recycler recycler; + + @VisibleForTesting + Object debugAttachment; Chunk(Chunk recycle) { @@ -571,14 +885,13 @@ final static class Chunk this.baseAddress = recycle.baseAddress; this.shift = recycle.shift; this.freeSlots = -1L; - this.original = recycle.original; - if (DEBUG) - globalPool.debug.recycle(original); + this.recycler = recycle.recycler; } - Chunk(ByteBuffer slab) + Chunk(Recycler recycler, ByteBuffer slab) { assert !slab.hasArray(); + this.recycler = recycler; this.slab = slab; this.baseAddress = MemoryUtil.getAddress(slab); @@ -587,7 +900,6 @@ final static class Chunk this.shift = 31 & (Integer.numberOfTrailingZeros(slab.capacity() / 64)); // -1 means all free whilst 0 means all in use this.freeSlots = slab.capacity() == 0 ? 0L : -1L; - this.original = DEBUG ? 
this : null; } /** @@ -621,7 +933,7 @@ void tryRecycle() void recycle() { assert freeSlots == 0L; - globalPool.recycle(new Chunk(this)); + recycler.recycle(this); } /** @@ -642,14 +954,12 @@ static Chunk getParentChunk(ByteBuffer buffer) return null; } - ByteBuffer setAttachment(ByteBuffer buffer) + void setAttachment(ByteBuffer buffer) { if (Ref.DEBUG_ENABLED) MemoryUtil.setAttachment(buffer, new Ref<>(this, null)); else MemoryUtil.setAttachment(buffer, this); - - return buffer; } boolean releaseAttachment(ByteBuffer buffer) @@ -658,22 +968,12 @@ boolean releaseAttachment(ByteBuffer buffer) if (attachment == null) return false; - if (attachment instanceof Ref) + if (Ref.DEBUG_ENABLED) ((Ref) attachment).release(); return true; } - @VisibleForTesting - void reset() - { - Chunk parent = getParentChunk(slab); - if (parent != null) - parent.free(slab, false); - else - FileUtils.clean(slab); - } - @VisibleForTesting long setFreeSlots(long val) { @@ -703,15 +1003,27 @@ int free() return Long.bitCount(freeSlots) * unit(); } + int freeSlotCount() + { + return Long.bitCount(freeSlots); + } + + ByteBuffer get(int size) + { + return get(size, false, null); + } + /** * Return the next available slice of this size. If * we have exceeded the capacity we return null. */ - ByteBuffer get(int size) + ByteBuffer get(int size, boolean sizeIsLowerBound, ByteBuffer into) { // how many multiples of our units is the size? 
// we add (unit - 1), so that when we divide by unit (>>> shift), we effectively round up int slotCount = (size - 1 + unit()) >>> shift; + if (sizeIsLowerBound) + size = slotCount << shift; // if we require more than 64 slots, we cannot possibly accommodate the allocation if (slotCount > 64) @@ -775,17 +1087,18 @@ ByteBuffer get(int size) // make sure no other thread has cleared the candidate bits assert ((candidate & cur) == candidate); } - return get(index << shift, size); + return set(index << shift, size, into); } } } - private ByteBuffer get(int offset, int size) + private ByteBuffer set(int offset, int size, ByteBuffer into) { - slab.limit(offset + size); - slab.position(offset); - - return setAttachment(slab.slice()); + if (into == null) + into = MemoryUtil.getHollowDirectByteBuffer(ByteOrder.BIG_ENDIAN); + MemoryUtil.sliceDirectByteBuffer(slab, into, offset, size); + setAttachment(into); + return into; } /** @@ -807,25 +1120,16 @@ long free(ByteBuffer buffer, boolean tryRelease) if (!releaseAttachment(buffer)) return 1L; + int size = roundUp(buffer.capacity()); long address = MemoryUtil.getAddress(buffer); - assert (address >= baseAddress) & (address <= baseAddress + capacity()); + assert (address >= baseAddress) & (address + size <= baseAddress + capacity()); - int position = (int)(address - baseAddress); - int size = roundUp(buffer.capacity()); + int position = ((int)(address - baseAddress)) >> shift; - position >>= shift; int slotCount = size >> shift; - - long slotBits = (1L << slotCount) - 1; + long slotBits = 0xffffffffffffffffL >>> (64 - slotCount); long shiftedSlotBits = (slotBits << position); - if (slotCount == 64) - { - assert size == capacity(); - assert position == 0; - shiftedSlotBits = -1L; - } - long next; while (true) { @@ -839,29 +1143,121 @@ long free(ByteBuffer buffer, boolean tryRelease) } } + void freeUnusedPortion(ByteBuffer buffer) + { + int size = roundUp(buffer.limit()); + int capacity = roundUp(buffer.capacity()); + if (size == 
capacity) + return; + + long address = MemoryUtil.getAddress(buffer); + assert (address >= baseAddress) & (address + size <= baseAddress + capacity()); + + // free any spare slots above the size we are using + int position = ((int)(address + size - baseAddress)) >> shift; + int slotCount = (capacity - size) >> shift; + + long slotBits = 0xffffffffffffffffL >>> (64 - slotCount); + long shiftedSlotBits = (slotBits << position); + + long next; + while (true) + { + long cur = freeSlots; + next = cur | shiftedSlotBits; + assert next == (cur ^ shiftedSlotBits); // ensure no double free + if (freeSlotsUpdater.compareAndSet(this, cur, next)) + break; + } + MemoryUtil.setByteBufferCapacity(buffer, size); + } + @Override public String toString() { return String.format("[slab %s, slots bitmap %s, capacity %d, free %d]", slab, Long.toBinaryString(freeSlots), capacity(), free()); } + + @VisibleForTesting + void unsafeFree() + { + Chunk parent = getParentChunk(slab); + if (parent != null) + parent.free(slab, false); + else + FileUtils.clean(slab); + } + + static void unsafeRecycle(Chunk chunk) + { + if (chunk != null) + { + chunk.owner = null; + chunk.freeSlots = 0L; + chunk.recycle(); + } + } } @VisibleForTesting - public static int roundUpNormal(int size) + public static int roundUp(int size) { - return roundUp(size, CHUNK_SIZE / 64); + if (size <= TINY_ALLOCATION_LIMIT) + return roundUp(size, TINY_ALLOCATION_UNIT); + return roundUp(size, NORMAL_ALLOCATION_UNIT); } - private static int roundUp(int size, int unit) + @VisibleForTesting + public static int roundUp(int size, int unit) { int mask = unit - 1; return (size + mask) & ~mask; } @VisibleForTesting - public static void shutdownLocalCleaner() throws InterruptedException + public static void shutdownLocalCleaner(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - EXEC.shutdown(); - EXEC.awaitTermination(60, TimeUnit.SECONDS); + shutdownNow(of(EXEC)); + awaitTermination(timeout, unit, of(EXEC)); } 
+ + public static long unsafeGetBytesInUse() + { + long totalMemory = globalPool.memoryUsage.get(); + class L { long v; } + final L availableMemory = new L(); + for (Chunk chunk : globalPool.chunks) + { + availableMemory.v += chunk.capacity(); + } + for (LocalPoolRef ref : localPoolReferences) + { + ref.chunks.forEach(chunk -> availableMemory.v += chunk.free()); + } + return totalMemory - availableMemory.v; + } + + /** This is not thread safe and should only be used for unit testing. */ + @VisibleForTesting + static void unsafeReset() + { + localPool.get().unsafeRecycle(); + globalPool.unsafeFree(); + } + + @VisibleForTesting + static Chunk unsafeCurrentChunk() + { + return localPool.get().chunks.chunk0; + } + + @VisibleForTesting + static int unsafeNumChunks() + { + LocalPool pool = localPool.get(); + return (pool.chunks.chunk0 != null ? 1 : 0) + + (pool.chunks.chunk1 != null ? 1 : 0) + + (pool.chunks.chunk2 != null ? 1 : 0); + } + } diff --git a/src/java/org/apache/cassandra/utils/memory/MemoryUtil.java b/src/java/org/apache/cassandra/utils/memory/MemoryUtil.java index 998cbbf1a0d7..6b30f44aa6cb 100644 --- a/src/java/org/apache/cassandra/utils/memory/MemoryUtil.java +++ b/src/java/org/apache/cassandra/utils/memory/MemoryUtil.java @@ -157,7 +157,7 @@ public static ByteBuffer getByteBuffer(long address, int length) public static ByteBuffer getByteBuffer(long address, int length, ByteOrder order) { ByteBuffer instance = getHollowDirectByteBuffer(order); - setByteBuffer(instance, address, length); + setDirectByteBuffer(instance, address, length); return instance; } @@ -196,13 +196,6 @@ public static ByteBuffer getHollowByteBuffer() return instance; } - public static void setByteBuffer(ByteBuffer instance, long address, int length) - { - unsafe.putLong(instance, DIRECT_BYTE_BUFFER_ADDRESS_OFFSET, address); - unsafe.putInt(instance, DIRECT_BYTE_BUFFER_CAPACITY_OFFSET, length); - unsafe.putInt(instance, DIRECT_BYTE_BUFFER_LIMIT_OFFSET, length); - } - public static 
Object getAttachment(ByteBuffer instance) { assert instance.getClass() == DIRECT_BYTE_BUFFER_CLASS; @@ -225,6 +218,26 @@ public static ByteBuffer duplicateDirectByteBuffer(ByteBuffer source, ByteBuffer return hollowBuffer; } + public static ByteBuffer sliceDirectByteBuffer(ByteBuffer source, ByteBuffer hollowBuffer, int offset, int length) + { + assert source.getClass() == DIRECT_BYTE_BUFFER_CLASS || source.getClass() == RO_DIRECT_BYTE_BUFFER_CLASS; + setDirectByteBuffer(hollowBuffer, offset + unsafe.getLong(source, DIRECT_BYTE_BUFFER_ADDRESS_OFFSET), length); + return hollowBuffer; + } + + public static void setDirectByteBuffer(ByteBuffer instance, long address, int length) + { + unsafe.putLong(instance, DIRECT_BYTE_BUFFER_ADDRESS_OFFSET, address); + unsafe.putInt(instance, DIRECT_BYTE_BUFFER_POSITION_OFFSET, 0); + unsafe.putInt(instance, DIRECT_BYTE_BUFFER_CAPACITY_OFFSET, length); + unsafe.putInt(instance, DIRECT_BYTE_BUFFER_LIMIT_OFFSET, length); + } + + public static void setByteBufferCapacity(ByteBuffer instance, int capacity) + { + unsafe.putInt(instance, DIRECT_BYTE_BUFFER_CAPACITY_OFFSET, capacity); + } + public static long getLongByByte(long address) { if (BIG_ENDIAN) diff --git a/src/java/org/apache/cassandra/utils/memory/MemtablePool.java b/src/java/org/apache/cassandra/utils/memory/MemtablePool.java index 684db93bb36d..5ef023f74142 100644 --- a/src/java/org/apache/cassandra/utils/memory/MemtablePool.java +++ b/src/java/org/apache/cassandra/utils/memory/MemtablePool.java @@ -19,6 +19,7 @@ package org.apache.cassandra.utils.memory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import com.google.common.annotations.VisibleForTesting; @@ -27,6 +28,7 @@ import org.apache.cassandra.metrics.CassandraMetricsRegistry; import org.apache.cassandra.metrics.DefaultNameFactory; import org.apache.cassandra.utils.concurrent.WaitQueue; +import 
org.apache.cassandra.utils.ExecutorUtils; /** @@ -66,10 +68,10 @@ MemtableCleanerThread getCleaner(Runnable cleaner) return cleaner == null ? null : new MemtableCleanerThread<>(this, cleaner); } - public void shutdown() throws InterruptedException + @VisibleForTesting + public void shutdownAndWait(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - cleaner.shutdown(); - cleaner.awaitTermination(60, TimeUnit.SECONDS); + ExecutorUtils.shutdownNowAndWait(timeout, unit, cleaner); } public abstract MemtableAllocator newAllocator(); diff --git a/src/java/org/apache/cassandra/utils/vint/VIntCoding.java b/src/java/org/apache/cassandra/utils/vint/VIntCoding.java index d4fdd630ea5d..6961d9f19148 100644 --- a/src/java/org/apache/cassandra/utils/vint/VIntCoding.java +++ b/src/java/org/apache/cassandra/utils/vint/VIntCoding.java @@ -49,10 +49,11 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.nio.ByteBuffer; -import io.netty.buffer.ByteBuf; import io.netty.util.concurrent.FastThreadLocal; import net.nicoulaj.compilecommand.annotations.Inline; +import org.apache.cassandra.io.util.DataInputPlus; /** * Borrows idea from @@ -82,37 +83,45 @@ public static long readUnsignedVInt(DataInput input) throws IOException return retval; } + public static void skipUnsignedVInt(DataInputPlus input) throws IOException + { + int firstByte = input.readByte(); + if (firstByte < 0) + input.skipBytesFully(numberOfExtraBytesToRead(firstByte)); + } + /** * Note this method is the same as {@link #readUnsignedVInt(DataInput)}, * except that we do *not* block if there are not enough bytes in the buffer * to reconstruct the value. * + * WARNING: this method is only safe for vints we know to be representable by a positive long value. + * * @return -1 if there are not enough bytes in the input to read the value; else, the vint unsigned value. 
*/ - public static long readUnsignedVInt(ByteBuf input) + public static long getUnsignedVInt(ByteBuffer input, int readerIndex) { - if (!input.isReadable()) + return getUnsignedVInt(input, readerIndex, input.limit()); + } + public static long getUnsignedVInt(ByteBuffer input, int readerIndex, int readerLimit) + { + if (readerIndex >= readerLimit) return -1; - input.markReaderIndex(); - int firstByte = input.readByte(); + int firstByte = input.get(readerIndex++); //Bail out early if this is one byte, necessary or it fails later if (firstByte >= 0) return firstByte; int size = numberOfExtraBytesToRead(firstByte); - - if (input.readableBytes() < size) - { - input.resetReaderIndex(); + if (readerIndex + size > readerLimit) return -1; - } long retval = firstByte & firstByteValueMask(size); for (int ii = 0; ii < size; ii++) { - byte b = input.readByte(); + byte b = input.get(readerIndex++); retval <<= 8; retval |= b & 0xff; } @@ -120,6 +129,24 @@ public static long readUnsignedVInt(ByteBuf input) return retval; } + /** + * Computes size of an unsigned vint that starts at readerIndex of the provided ByteBuf. + * + * @return -1 if there are not enough bytes in the input to calculate the size; else, the vint unsigned value size in bytes. + */ + public static int computeUnsignedVIntSize(ByteBuffer input, int readerIndex) + { + return computeUnsignedVIntSize(input, readerIndex, input.limit()); + } + public static int computeUnsignedVIntSize(ByteBuffer input, int readerIndex, int readerLimit) + { + if (readerIndex >= readerLimit) + return -1; + + int firstByte = input.get(readerIndex); + return 1 + ((firstByte >= 0) ? 
0 : numberOfExtraBytesToRead(firstByte)); + } + public static long readVInt(DataInput input) throws IOException { return decodeZigZag64(readUnsignedVInt(input)); @@ -155,6 +182,7 @@ public byte[] initialValue() } }; + @Inline public static void writeUnsignedVInt(long value, DataOutput output) throws IOException { int size = VIntCoding.computeUnsignedVIntSize(value); @@ -164,24 +192,47 @@ public static void writeUnsignedVInt(long value, DataOutput output) throws IOExc return; } - output.write(VIntCoding.encodeVInt(value, size), 0, size); + output.write(VIntCoding.encodeUnsignedVInt(value, size), 0, size); } @Inline - public static byte[] encodeVInt(long value, int size) + public static void writeUnsignedVInt(long value, ByteBuffer output) { - byte encodingSpace[] = encodingBuffer.get(); - int extraBytes = size - 1; + int size = VIntCoding.computeUnsignedVIntSize(value); + if (size == 1) + { + output.put((byte) value); + return; + } + + output.put(VIntCoding.encodeUnsignedVInt(value, size), 0, size); + } + /** + * @return a TEMPORARY THREAD LOCAL BUFFER containing the encoded bytes of the value + * This byte[] must be discarded by the caller immediately, and synchronously + */ + @Inline + private static byte[] encodeUnsignedVInt(long value, int size) + { + byte[] encodingSpace = encodingBuffer.get(); + encodeUnsignedVInt(value, size, encodingSpace); + return encodingSpace; + } + + @Inline + private static void encodeUnsignedVInt(long value, int size, byte[] encodeInto) + { + int extraBytes = size - 1; for (int i = extraBytes ; i >= 0; --i) { - encodingSpace[i] = (byte) value; + encodeInto[i] = (byte) value; value >>= 8; } - encodingSpace[0] |= VIntCoding.encodeExtraBytesToRead(extraBytes); - return encodingSpace; + encodeInto[0] |= VIntCoding.encodeExtraBytesToRead(extraBytes); } + @Inline public static void writeVInt(long value, DataOutput output) throws IOException { writeUnsignedVInt(encodeZigZag64(value), output); diff --git 
a/test/burn/org/apache/cassandra/net/BytesInFlightController.java b/test/burn/org/apache/cassandra/net/BytesInFlightController.java new file mode 100644 index 000000000000..edd9a7e1b91e --- /dev/null +++ b/test/burn/org/apache/cassandra/net/BytesInFlightController.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.locks.LockSupport; +import java.util.function.IntConsumer; + +import org.apache.cassandra.utils.Pair; + +public class BytesInFlightController +{ + private static final AtomicLongFieldUpdater sentBytesUpdater = AtomicLongFieldUpdater.newUpdater(BytesInFlightController.class, "sentBytes"); + private static final AtomicLongFieldUpdater receivedBytesUpdater = AtomicLongFieldUpdater.newUpdater(BytesInFlightController.class, "receivedBytes"); + + private volatile long minimumInFlightBytes, maximumInFlightBytes; + private volatile long sentBytes; + private volatile long receivedBytes; + private final ConcurrentLinkedQueue> deferredBytes = new ConcurrentLinkedQueue<>(); + private final ConcurrentSkipListMap waitingToSend = new ConcurrentSkipListMap<>(); + + BytesInFlightController(long maximumInFlightBytes) + { + this.maximumInFlightBytes = maximumInFlightBytes; + } + + public void send(long bytes) throws InterruptedException + { + long sentBytes = sentBytesUpdater.getAndAdd(this, bytes); + maybeProcessDeferred(); + if ((sentBytes - receivedBytes) >= maximumInFlightBytes) + { + long waitUntilReceived = sentBytes - maximumInFlightBytes; + // overlap shouldn't occur, but cannot guarantee it when we modify maximumInFlightBytes + Thread prev = waitingToSend.putIfAbsent(waitUntilReceived, Thread.currentThread()); + while (prev != null) + prev = waitingToSend.putIfAbsent(++waitUntilReceived, Thread.currentThread()); + + boolean isInterrupted; + while (!(isInterrupted = Thread.currentThread().isInterrupted()) + && waitUntilReceived - receivedBytes >= 0 + && waitingToSend.get(waitUntilReceived) != null) + { + LockSupport.park(); + } + waitingToSend.remove(waitUntilReceived); + + if (isInterrupted) + throw new 
InterruptedException(); + } + } + + public long minimumInFlightBytes() { return minimumInFlightBytes; } + public long maximumInFlightBytes() { return maximumInFlightBytes; } + + void adjust(int predictedSentBytes, int actualSentBytes) + { + receivedBytesUpdater.addAndGet(this, predictedSentBytes - actualSentBytes); + if (predictedSentBytes > actualSentBytes) wakeupSenders(); + else maybeProcessDeferred(); + } + + public long inFlight() + { + return sentBytes - receivedBytes; + } + + public void fail(int bytes) + { + receivedBytesUpdater.addAndGet(this, bytes); + wakeupSenders(); + } + + public void process(int bytes, IntConsumer releaseBytes) + { + while (true) + { + long sent = sentBytes; + long received = receivedBytes; + long newReceived = received + bytes; + if (sent - newReceived <= minimumInFlightBytes) + { + deferredBytes.add(Pair.create(bytes, releaseBytes)); + break; + } + if (receivedBytesUpdater.compareAndSet(this, received, newReceived)) + { + releaseBytes.accept(bytes); + break; + } + } + maybeProcessDeferred(); + wakeupSenders(); + } + + void setInFlightByteBounds(long minimumInFlightBytes, long maximumInFlightBytes) + { + this.minimumInFlightBytes = minimumInFlightBytes; + this.maximumInFlightBytes = maximumInFlightBytes; + maybeProcessDeferred(); + } + + // unlike the rest of the class, this method does not handle wrap-around of sent/received; + // since this shouldn't happen it's no big deal, but maybe for absurdly long runs it might. + // if so, fix it. 
+ private void wakeupSenders() + { + Map.Entry next; + while (null != (next = waitingToSend.firstEntry())) + { + if (next.getKey() - receivedBytes >= 0) + break; + + if (waitingToSend.remove(next.getKey(), next.getValue())) + LockSupport.unpark(next.getValue()); + } + } + + private void maybeProcessDeferred() + { + while (true) + { + long sent = sentBytes; + long received = receivedBytes; + if (sent - received <= minimumInFlightBytes) + break; + + Pair next = deferredBytes.poll(); + if (next == null) + break; + + int receive = next.left; + IntConsumer callbacks = next.right; + while (true) + { + long newReceived = received + receive; + if (receivedBytesUpdater.compareAndSet(this, received, newReceived)) + { + callbacks.accept(receive); + wakeupSenders(); + break; + } + + sent = sentBytes; + received = receivedBytes; + if (sent - received <= minimumInFlightBytes) + { + deferredBytes.add(next); + break; // continues with outer loop to maybe process it if minimumInFlightBytes has changed meanwhile + } + } + } + } + +} diff --git a/test/burn/org/apache/cassandra/net/Connection.java b/test/burn/org/apache/cassandra/net/Connection.java new file mode 100644 index 000000000000..c74c0ae0a166 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/Connection.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Verifier.Destiny; + +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.utils.ExecutorUtils.runWithThreadName; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +public class Connection implements InboundMessageCallbacks, OutboundMessageCallbacks, OutboundDebugCallbacks +{ + static class IntentionalIOException extends IOException {} + static class IntentionalRuntimeException extends RuntimeException {} + + final InetAddressAndPort sender; + final InetAddressAndPort recipient; + final BytesInFlightController controller; + final InboundMessageHandlers inbound; + final OutboundConnection outbound; + final OutboundConnectionSettings outboundTemplate; + final Verifier verifier; + final MessageGenerator sendGenerator; + final String linkId; + final long minId; + final long maxId; + final AtomicInteger isSending = new AtomicInteger(); + private volatile Runnable onSync; + final Lock managementLock = new ReentrantLock(); + + private final AtomicLong nextSendId = new AtomicLong(); + + Connection(InetAddressAndPort sender, InetAddressAndPort recipient, ConnectionType type, + InboundMessageHandlers inbound, + OutboundConnectionSettings 
outboundTemplate, ResourceLimits.EndpointAndGlobal reserveCapacityInBytes, + MessageGenerator generator, + long minId, long maxId) + { + this.sender = sender; + this.recipient = recipient; + this.controller = new BytesInFlightController(1 << 20); + this.sendGenerator = generator.copy(); + this.minId = minId; + this.maxId = maxId; + this.nextSendId.set(minId); + this.linkId = sender.toString(false) + "->" + recipient.toString(false) + "-" + type; + this.outboundTemplate = outboundTemplate.toEndpoint(recipient) + .withFrom(sender) + .withCallbacks(this) + .withDebugCallbacks(this); + this.inbound = inbound; + this.outbound = new OutboundConnection(type, this.outboundTemplate, reserveCapacityInBytes); + this.verifier = new Verifier(controller, outbound, inbound); + } + + void startVerifier(Runnable onFailure, Executor executor, long deadlineNanos) + { + executor.execute(runWithThreadName(() -> verifier.run(onFailure, deadlineNanos), "Verify-" + linkId)); + } + + boolean isSending() + { + return isSending.get() > 0; + } + + boolean registerSender() + { + return isSending.updateAndGet(i -> i < 0 ? i : i + 1) > 0; + } + + void unregisterSender() + { + if (isSending.updateAndGet(i -> i < 0 ? 
i + 1 : i - 1) == -1) + { + Runnable onSync = this.onSync; + this.onSync = null; + verifier.onSync(() -> { + onSync.run(); + isSending.set(0); + }); + } + } + + boolean setInFlightByteBounds(long minBytes, long maxBytes) + { + if (managementLock.tryLock()) + { + try + { + if (isSending.get() >= 0) + { + controller.setInFlightByteBounds(minBytes, maxBytes); + return true; + } + } + finally + { + managementLock.unlock(); + } + } + return false; + } + + void sync(Runnable onCompletion) + { + managementLock.lock(); + try + { + assert onSync == null; + assert isSending.get() >= 0; + isSending.updateAndGet(i -> -2 -i); + long previousMin = controller.minimumInFlightBytes(); + long previousMax = controller.maximumInFlightBytes(); + controller.setInFlightByteBounds(0, Long.MAX_VALUE); + onSync = () -> { + long inFlight = controller.inFlight(); + if (inFlight != 0) + verifier.logFailure("%s has %d bytes in flight, but connection is idle", linkId, inFlight); + controller.setInFlightByteBounds(previousMin, previousMax); + onCompletion.run(); + }; + unregisterSender(); + } + finally + { + managementLock.unlock(); + } + } + + void sendOne() throws InterruptedException + { + long id = nextSendId.getAndUpdate(i -> i == maxId ? minId : i + 1); + try + { + Destiny destiny = Destiny.SUCCEED; + byte realDestiny = 0; + Message msg; + synchronized (sendGenerator) + { + if (0 == sendGenerator.uniformInt(1 << 10)) + { + // abnormal destiny + realDestiny = (byte) (1 + sendGenerator.uniformInt(6)); + destiny = realDestiny <= 3 ? 
Destiny.FAIL_TO_SERIALIZE : Destiny.FAIL_TO_DESERIALIZE; + } + msg = sendGenerator.generate(id, realDestiny); + } + + controller.send(msg.serializedSize(current_version)); + Verifier.EnqueueMessageEvent e = verifier.onEnqueue(msg, destiny); + outbound.enqueue(msg); + e.complete(verifier); + } + catch (ClosedChannelException e) + { + // TODO: make this a tested, not illegal, state + throw new IllegalStateException(e); + } + } + + void reconnectWith(OutboundConnectionSettings template) + { + outbound.reconnectWith(template); + } + + void serialize(long id, byte[] payload, DataOutputPlus out, int messagingVersion) throws IOException + { + verifier.onSerialize(id, messagingVersion); + int firstWrite = payload.length, remainder = 0; + boolean willFail = false; + if (outbound.type() != ConnectionType.LARGE_MESSAGES || messagingVersion >= VERSION_40) + { + // We cannot (with Netty) know how many bytes make it to the network as any partially written block + // will be failed despite having partially succeeded. So to support this behaviour here, we would + // need to accept either outcome, in which case what is the point? 
+ // TODO: it would be nice to fix this, still + willFail = outbound.type() != ConnectionType.LARGE_MESSAGES; + byte info = MessageGenerator.getInfo(payload); + switch (info) + { + case 1: + switch ((int) (id & 1)) + { + case 0: throw new IntentionalIOException(); + case 1: throw new IntentionalRuntimeException(); + } + break; + case 2: + willFail = true; + firstWrite -= (int)id % payload.length; + break; + case 3: + willFail = true; + remainder = (int)id & 65535; + break; + } + } + + MessageGenerator.writeLength(payload, out, messagingVersion); + out.write(payload, 0, firstWrite); + while (remainder > 0) + { + out.write(payload, 0, Math.min(remainder, payload.length)); + remainder -= payload.length; + } + if (!willFail) + verifier.onFinishSerializeLarge(id); + } + + byte[] deserialize(MessageGenerator.Header header, DataInputPlus in, int messagingVersion) throws IOException + { + verifier.onDeserialize(header.id, messagingVersion); + int length = header.length; + switch (header.info) + { + case 4: + switch ((int) (header.id & 1)) + { + case 0: throw new IntentionalIOException(); + case 1: throw new IntentionalRuntimeException(); + } + break; + case 5: { + length -= (int)header.id % header.length; + break; + } + case 6: { + length += (int)header.id & 65535; + break; + } + } + byte[] result = header.read(in, Math.min(header.length, length), messagingVersion); + if (length > header.length) + { + length -= header.length; + while (length >= 8) + { + in.readLong(); + length -= 8; + } + while (length-- > 0) + in.readByte(); + } + return result; + } + + public void process(Message message) + { + verifier.process(message); + } + + public void onHeaderArrived(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + } + + public void onArrived(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + verifier.onArrived(header.id, messageSize); + } + + public void onArrivedExpired(int messageSize, Message.Header header, boolean 
wasCorrupt, long timeElapsed, TimeUnit timeUnit) + { + controller.fail(messageSize); + verifier.onArrivedExpired(header.id, messageSize, wasCorrupt, timeElapsed, timeUnit); + } + + public void onArrivedCorrupt(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + controller.fail(messageSize); + verifier.onFailedDeserialize(header.id, messageSize); + } + + public void onClosedBeforeArrival(int messageSize, Message.Header header, int bytesReceived, boolean wasCorrupt, boolean wasExpired) + { + controller.fail(messageSize); + verifier.onClosedBeforeArrival(header.id, messageSize); + } + + public void onFailedDeserialize(int messageSize, Message.Header header, Throwable t) + { + controller.fail(messageSize); + verifier.onFailedDeserialize(header.id, messageSize); + } + + public void onDispatched(int messageSize, Message.Header header) + { + } + + public void onExecuting(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + } + + public void onProcessed(int messageSize, Message.Header header) + { + } + + public void onExpired(int messageSize, Message.Header header, long timeElapsed, TimeUnit timeUnit) + { + controller.fail(messageSize); + verifier.onProcessExpired(header.id, messageSize, timeElapsed, timeUnit); + } + + public void onExecuted(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + } + + InboundCounters inboundCounters() + { + return inbound.countersFor(outbound.type()); + } + + public void onSendSmallFrame(int messageCount, int payloadSizeInBytes) + { + verifier.onSendFrame(messageCount, payloadSizeInBytes); + } + + public void onSentSmallFrame(int messageCount, int payloadSizeInBytes) + { + verifier.onSentFrame(messageCount, payloadSizeInBytes); + } + + public void onFailedSmallFrame(int messageCount, int payloadSizeInBytes) + { + controller.fail(payloadSizeInBytes); + verifier.onFailedFrame(messageCount, payloadSizeInBytes); + } + + public void onConnect(int messagingVersion, 
OutboundConnectionSettings settings) + { + verifier.onConnectOutbound(messagingVersion, settings); + } + + public void onConnectInbound(int messagingVersion, InboundMessageHandler handler) + { + verifier.onConnectInbound(messagingVersion, handler); + } + + public void onOverloaded(Message message, InetAddressAndPort peer) + { + controller.fail(message.serializedSize(current_version)); + verifier.onOverloaded(message.id()); + } + + public void onExpired(Message message, InetAddressAndPort peer) + { + controller.fail(message.serializedSize(current_version)); + verifier.onExpiredBeforeSend(message.id(), message.serializedSize(current_version), approxTime.now() - message.createdAtNanos(), TimeUnit.NANOSECONDS); + } + + public void onFailedSerialize(Message message, InetAddressAndPort peer, int messagingVersion, int bytesWrittenToNetwork, Throwable failure) + { + if (bytesWrittenToNetwork == 0) + controller.fail(message.serializedSize(messagingVersion)); + verifier.onFailedSerialize(message.id(), bytesWrittenToNetwork, failure); + } + + public void onDiscardOnClose(Message message, InetAddressAndPort peer) + { + controller.fail(message.serializedSize(current_version)); + verifier.onFailedClosing(message.id()); + } + + public String toString() + { + return linkId; + } +} + diff --git a/test/burn/org/apache/cassandra/net/ConnectionBurnTest.java b/test/burn/org/apache/cassandra/net/ConnectionBurnTest.java new file mode 100644 index 000000000000..81b6402c5479 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/ConnectionBurnTest.java @@ -0,0 +1,656 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.IntConsumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Uninterruptibles; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.channel.Channel; +import net.openhft.chronicle.core.util.ThrowingBiConsumer; +import net.openhft.chronicle.core.util.ThrowingRunnable; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessageGenerator.UniformPayloadGenerator; +import org.apache.cassandra.utils.ExecutorUtils; +import org.apache.cassandra.utils.MonotonicClock; 
+import org.apache.cassandra.utils.memory.BufferPool; + +import static java.lang.Math.min; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; +import static org.apache.cassandra.utils.MonotonicClock.preciseTime; + +public class ConnectionBurnTest +{ + private static final Logger logger = LoggerFactory.getLogger(ConnectionBurnTest.class); + + static + { + // stop updating ALMOST_SAME_TIME so that we get consistent message expiration times + ((MonotonicClock.AbstractEpochSamplingClock) preciseTime).pauseEpochSampling(); + DatabaseDescriptor.daemonInitialization(); + DatabaseDescriptor.setCrossNodeTimeout(true); + } + + static class NoGlobalInboundMetrics implements InboundMessageHandlers.GlobalMetricCallbacks + { + static final NoGlobalInboundMetrics instance = new NoGlobalInboundMetrics(); + public LatencyConsumer internodeLatencyRecorder(InetAddressAndPort to) + { + return (timeElapsed, timeUnit) -> {}; + } + public void recordInternalLatency(Verb verb, long timeElapsed, TimeUnit timeUnit) {} + public void recordInternodeDroppedMessage(Verb verb, long timeElapsed, TimeUnit timeUnit) {} + } + + static class Inbound + { + final Map> handlersByRecipientThenSender; + final InboundSockets sockets; + + Inbound(List endpoints, GlobalInboundSettings settings, Test test) + { + final InboundMessageHandlers.GlobalResourceLimits globalInboundLimits = new InboundMessageHandlers.GlobalResourceLimits(new ResourceLimits.Concurrent(settings.globalReserveLimit)); + Map> handlersByRecipientThenSender = new HashMap<>(); + List bind = new ArrayList<>(); + for (InetAddressAndPort recipient : endpoints) + { + Map handlersBySender = new HashMap<>(); + for (InetAddressAndPort sender : endpoints) + handlersBySender.put(sender, new InboundMessageHandlers(recipient, sender, settings.queueCapacity, settings.endpointReserveLimit, 
globalInboundLimits, NoGlobalInboundMetrics.instance, test, test)); + + handlersByRecipientThenSender.put(recipient, handlersBySender); + bind.add(settings.template.withHandlers(handlersBySender::get).withBindAddress(recipient)); + } + this.sockets = new InboundSockets(bind); + this.handlersByRecipientThenSender = handlersByRecipientThenSender; + } + } + + private static class ConnectionKey + { + final InetAddressAndPort from; + final InetAddressAndPort to; + final ConnectionType type; + + private ConnectionKey(InetAddressAndPort from, InetAddressAndPort to, ConnectionType type) + { + this.from = from; + this.to = to; + this.type = type; + } + + public boolean equals(Object o) + { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConnectionKey that = (ConnectionKey) o; + return Objects.equals(from, that.from) && + Objects.equals(to, that.to) && + type == that.type; + } + + public int hashCode() + { + return Objects.hash(from, to, type); + } + } + + private static class Test implements InboundMessageHandlers.HandlerProvider, InboundMessageHandlers.MessageConsumer + { + private final IVersionedSerializer serializer = new IVersionedSerializer() + { + public void serialize(byte[] payload, DataOutputPlus out, int version) throws IOException + { + long id = MessageGenerator.getId(payload); + forId(id).serialize(id, payload, out, version); + } + + public byte[] deserialize(DataInputPlus in, int version) throws IOException + { + MessageGenerator.Header header = MessageGenerator.readHeader(in, version); + return forId(header.id).deserialize(header, in, version); + } + + public long serializedSize(byte[] payload, int version) + { + return MessageGenerator.serializedSize(payload, version); + } + }; + + static class Builder + { + long time; + TimeUnit timeUnit; + int endpoints; + MessageGenerators generators; + OutboundConnectionSettings outbound; + GlobalInboundSettings inbound; + public Builder time(long time, TimeUnit timeUnit) { 
this.time = time; this.timeUnit = timeUnit; return this; } + public Builder endpoints(int endpoints) { this.endpoints = endpoints; return this; } + public Builder inbound(GlobalInboundSettings inbound) { this.inbound = inbound; return this; } + public Builder outbound(OutboundConnectionSettings outbound) { this.outbound = outbound; return this; } + public Builder generators(MessageGenerators generators) { this.generators = generators; return this; } + Test build() { return new Test(endpoints, generators, inbound, outbound, timeUnit.toNanos(time)); } + } + + static Builder builder() { return new Builder(); } + + private static final int messageIdsPerConnection = 1 << 20; + + final long runForNanos; + final int version; + final List endpoints; + final Inbound inbound; + final Connection[] connections; + final long[] connectionMessageIds; + final ExecutorService executor = Executors.newCachedThreadPool(); + final Map connectionLookup = new HashMap<>(); + + private Test(int simulateEndpoints, MessageGenerators messageGenerators, GlobalInboundSettings inboundSettings, OutboundConnectionSettings outboundTemplate, long runForNanos) + { + this.endpoints = endpoints(simulateEndpoints); + this.inbound = new Inbound(endpoints, inboundSettings, this); + this.connections = new Connection[endpoints.size() * endpoints.size() * 3]; + this.connectionMessageIds = new long[connections.length]; + this.version = outboundTemplate.acceptVersions == null ? 
current_version : outboundTemplate.acceptVersions.max; + this.runForNanos = runForNanos; + + int i = 0; + long minId = 0, maxId = messageIdsPerConnection - 1; + for (InetAddressAndPort recipient : endpoints) + { + for (InetAddressAndPort sender : endpoints) + { + InboundMessageHandlers inboundHandlers = inbound.handlersByRecipientThenSender.get(recipient).get(sender); + OutboundConnectionSettings template = outboundTemplate.withDefaultReserveLimits(); + ResourceLimits.Limit reserveEndpointCapacityInBytes = new ResourceLimits.Concurrent(template.applicationSendQueueReserveEndpointCapacityInBytes); + ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(reserveEndpointCapacityInBytes, template.applicationSendQueueReserveGlobalCapacityInBytes); + for (ConnectionType type : ConnectionType.MESSAGING_TYPES) + { + Connection connection = new Connection(sender, recipient, type, inboundHandlers, template, reserveCapacityInBytes, messageGenerators.get(type), minId, maxId); + this.connections[i] = connection; + this.connectionMessageIds[i] = minId; + connectionLookup.put(new ConnectionKey(sender, recipient, type), connection); + minId = maxId + 1; + maxId += messageIdsPerConnection; + ++i; + } + } + } + } + + Connection forId(long messageId) + { + int i = Arrays.binarySearch(connectionMessageIds, messageId); + if (i < 0) i = -2 -i; + Connection connection = connections[i]; + assert connection.minId <= messageId && connection.maxId >= messageId; + return connection; + } + + List getConnections(InetAddressAndPort endpoint, boolean inbound) + { + List result = new ArrayList<>(); + for (ConnectionType type : ConnectionType.MESSAGING_TYPES) + { + for (InetAddressAndPort other : endpoints) + { + result.add(connectionLookup.get(inbound ? new ConnectionKey(other, endpoint, type) + : new ConnectionKey(endpoint, other, type))); + } + } + result.forEach(c -> {assert endpoint.equals(inbound ? 
c.recipient : c.sender); }); + return result; + } + + public void run() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException, TimeoutException + { + Reporters reporters = new Reporters(endpoints, connections); + try + { + long deadline = System.nanoTime() + runForNanos; + Verb._TEST_2.unsafeSetHandler(() -> message -> {}); + Verb._TEST_2.unsafeSetSerializer(() -> serializer); + inbound.sockets.open().get(); + + CountDownLatch failed = new CountDownLatch(1); + for (Connection connection : connections) + connection.startVerifier(failed::countDown, executor, deadline); + + for (int i = 0 ; i < 2 * connections.length ; ++i) + { + executor.execute(() -> { + String threadName = Thread.currentThread().getName(); + try + { + ThreadLocalRandom random = ThreadLocalRandom.current(); + while (approxTime.now() < deadline && !Thread.currentThread().isInterrupted()) + { + Connection connection = connections[random.nextInt(connections.length)]; + if (!connection.registerSender()) + continue; + + try + { + Thread.currentThread().setName("Generate-" + connection.linkId); + int count = 0; + switch (random.nextInt() & 3) + { + case 0: count = random.nextInt(100, 200); break; + case 1: count = random.nextInt(200, 1000); break; + case 2: count = random.nextInt(1000, 2000); break; + case 3: count = random.nextInt(2000, 10000); break; + } + + if (connection.outbound.type() == LARGE_MESSAGES) + count /= 2; + + while (connection.isSending() + && count-- > 0 + && approxTime.now() < deadline + && !Thread.currentThread().isInterrupted()) + connection.sendOne(); + } + finally + { + Thread.currentThread().setName(threadName); + connection.unregisterSender(); + } + } + } + catch (Throwable t) + { + if (t instanceof InterruptedException) + return; + logger.error("Unexpected exception", t); + failed.countDown(); + } + }); + } + + executor.execute(() -> { + Thread.currentThread().setName("Test-SetInFlight"); + ThreadLocalRandom random = 
ThreadLocalRandom.current(); + List connections = new ArrayList<>(Arrays.asList(this.connections)); + while (!Thread.currentThread().isInterrupted()) + { + Collections.shuffle(connections); + int total = random.nextInt(1 << 20, 128 << 20); + for (int i = connections.size() - 1; i >= 1 ; --i) + { + int average = total / (i + 1); + int max = random.nextInt(1, min(2 * average, total - 2)); + int min = random.nextInt(0, max); + connections.get(i).setInFlightByteBounds(min, max); + total -= max; + } + // note that setInFlightByteBounds might not + connections.get(0).setInFlightByteBounds(random.nextInt(0, total), total); + Uninterruptibles.sleepUninterruptibly(1L, TimeUnit.SECONDS); + } + }); + + // TODO: slowly modify the pattern of interrupts, from often to infrequent + executor.execute(() -> { + Thread.currentThread().setName("Test-Reconnect"); + ThreadLocalRandom random = ThreadLocalRandom.current(); + while (deadline > System.nanoTime()) + { + try + { + Thread.sleep(random.nextInt(60000)); + } + catch (InterruptedException e) + { + break; + } + Connection connection = connections[random.nextInt(connections.length)]; + OutboundConnectionSettings template = connection.outboundTemplate; + template = ConnectionTest.SETTINGS.get(random.nextInt(ConnectionTest.SETTINGS.size())) + .outbound.apply(template); + connection.reconnectWith(template); + } + }); + + executor.execute(() -> { + Thread.currentThread().setName("Test-Sync"); + ThreadLocalRandom random = ThreadLocalRandom.current(); + BiConsumer> checkStoppedTo = (to, from) -> { + InboundMessageHandlers handlers = from.get(0).inbound; + long using = handlers.usingCapacity(); + long usingReserve = handlers.usingEndpointReserveCapacity(); + if (using != 0 || usingReserve != 0) + { + String message = to + " inbound using %d capacity and %d reserve; should be zero"; + from.get(0).verifier.logFailure(message, using, usingReserve); + } + }; + BiConsumer> checkStoppedFrom = (from, to) -> { + long using = to.stream().map(c -> 
c.outbound).mapToLong(OutboundConnection::pendingBytes).sum(); + long usingReserve = to.get(0).outbound.unsafeGetEndpointReserveLimits().using(); + if (using != 0 || usingReserve != 0) + { + String message = from + " outbound using %d capacity and %d reserve; should be zero"; + to.get(0).verifier.logFailure(message, using, usingReserve); + } + }; + ThrowingBiConsumer, ThrowingRunnable, InterruptedException> sync = + (connections, exec) -> { + logger.info("Syncing connections: {}", connections); + final CountDownLatch ready = new CountDownLatch(connections.size()); + final CountDownLatch done = new CountDownLatch(1); + for (Connection connection : connections) + { + connection.sync(() -> { + ready.countDown(); + try { done.await(); } + catch (InterruptedException e) { Thread.interrupted(); } + }); + } + ready.await(); + try + { + exec.run(); + } + finally + { + done.countDown(); + } + logger.info("Sync'd connections: {}", connections); + }; + + int count = 0; + while (deadline > System.nanoTime()) + { + + try + { + Thread.sleep(random.nextInt(10000)); + + if (++count % 10 == 0) +// { +// boolean checkInbound = random.nextBoolean(); +// BiConsumer> verifier = checkInbound ? 
checkStoppedTo : checkStoppedFrom; +// InetAddressAndPort endpoint = endpoints.get(random.nextInt(endpoints.size())); +// List connections = getConnections(endpoint, checkInbound); +// sync.accept(connections, () -> verifier.accept(endpoint, connections)); +// } +// else if (count % 100 == 0) + { + sync.accept(ImmutableList.copyOf(connections), () -> { + + for (InetAddressAndPort endpoint : endpoints) + { + checkStoppedTo .accept(endpoint, getConnections(endpoint, true )); + checkStoppedFrom.accept(endpoint, getConnections(endpoint, false)); + } + long inUse = BufferPool.unsafeGetBytesInUse(); + if (inUse > 0) + { +// try +// { +// ManagementFactory.getPlatformMXBean(HotSpotDiagnosticMXBean.class).dumpHeap("/Users/belliottsmith/code/cassandra/cassandra/leak.hprof", true); +// } +// catch (IOException e) +// { +// throw new RuntimeException(e); +// } + connections[0].verifier.logFailure("Using %d bytes of BufferPool, but all connections are idle", inUse); + } + }); + } + else + { + CountDownLatch latch = new CountDownLatch(1); + Connection connection = connections[random.nextInt(connections.length)]; + connection.sync(latch::countDown); + latch.await(); + } + } + catch (InterruptedException e) + { + break; + } + } + }); + + while (deadline > System.nanoTime() && failed.getCount() > 0) + { + reporters.update(); + reporters.print(); + Uninterruptibles.awaitUninterruptibly(failed, 30L, TimeUnit.SECONDS); + } + + executor.shutdownNow(); + ExecutorUtils.awaitTermination(1L, TimeUnit.MINUTES, executor); + } + finally + { + reporters.update(); + reporters.print(); + + inbound.sockets.close().get(); + new FutureCombiner(Arrays.stream(connections) + .map(c -> c.outbound.close(false)) + .collect(Collectors.toList())) + .get(); + } + } + + class WrappedInboundCallbacks implements InboundMessageCallbacks + { + private final InboundMessageCallbacks wrapped; + + WrappedInboundCallbacks(InboundMessageCallbacks wrapped) + { + this.wrapped = wrapped; + } + + public void 
onHeaderArrived(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onHeaderArrived(messageSize, header, timeElapsed, unit); + wrapped.onHeaderArrived(messageSize, header, timeElapsed, unit); + } + + public void onArrived(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onArrived(messageSize, header, timeElapsed, unit); + wrapped.onArrived(messageSize, header, timeElapsed, unit); + } + + public void onArrivedExpired(int messageSize, Message.Header header, boolean wasCorrupt, long timeElapsed, TimeUnit unit) + { + forId(header.id).onArrivedExpired(messageSize, header, wasCorrupt, timeElapsed, unit); + wrapped.onArrivedExpired(messageSize, header, wasCorrupt, timeElapsed, unit); + } + + public void onArrivedCorrupt(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onArrivedCorrupt(messageSize, header, timeElapsed, unit); + wrapped.onArrivedCorrupt(messageSize, header, timeElapsed, unit); + } + + public void onClosedBeforeArrival(int messageSize, Message.Header header, int bytesReceived, boolean wasCorrupt, boolean wasExpired) + { + forId(header.id).onClosedBeforeArrival(messageSize, header, bytesReceived, wasCorrupt, wasExpired); + wrapped.onClosedBeforeArrival(messageSize, header, bytesReceived, wasCorrupt, wasExpired); + } + + public void onFailedDeserialize(int messageSize, Message.Header header, Throwable t) + { + forId(header.id).onFailedDeserialize(messageSize, header, t); + wrapped.onFailedDeserialize(messageSize, header, t); + } + + public void onDispatched(int messageSize, Message.Header header) + { + forId(header.id).onDispatched(messageSize, header); + wrapped.onDispatched(messageSize, header); + } + + public void onExecuting(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onExecuting(messageSize, header, timeElapsed, unit); + wrapped.onExecuting(messageSize, header, 
timeElapsed, unit); + } + + public void onProcessed(int messageSize, Message.Header header) + { + forId(header.id).onProcessed(messageSize, header); + wrapped.onProcessed(messageSize, header); + } + + public void onExpired(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onExpired(messageSize, header, timeElapsed, unit); + wrapped.onExpired(messageSize, header, timeElapsed, unit); + } + + public void onExecuted(int messageSize, Message.Header header, long timeElapsed, TimeUnit unit) + { + forId(header.id).onExecuted(messageSize, header, timeElapsed, unit); + wrapped.onExecuted(messageSize, header, timeElapsed, unit); + } + } + + public void fail(Message.Header header, Throwable failure) + { +// forId(header.id).verifier.logFailure("Unexpected failure", failure); + } + + public void accept(Message message) + { + forId(message.id()).process(message); + } + + public InboundMessageCallbacks wrap(InboundMessageCallbacks wrap) + { + return new WrappedInboundCallbacks(wrap); + } + + public InboundMessageHandler provide( + FrameDecoder decoder, + ConnectionType type, + Channel channel, + InetAddressAndPort self, + InetAddressAndPort peer, + int version, + int largeMessageThreshold, + int queueCapacity, + ResourceLimits.Limit endpointReserveCapacity, + ResourceLimits.Limit globalReserveCapacity, + InboundMessageHandler.WaitQueue endpointWaitQueue, + InboundMessageHandler.WaitQueue globalWaitQueue, + InboundMessageHandler.OnHandlerClosed onClosed, + InboundMessageCallbacks callbacks, + Consumer> messageSink + ) + { + return new InboundMessageHandler(decoder, type, channel, self, peer, version, largeMessageThreshold, queueCapacity, endpointReserveCapacity, globalReserveCapacity, endpointWaitQueue, globalWaitQueue, onClosed, wrap(callbacks), messageSink) + { + final IntConsumer releaseCapacity = size -> super.releaseProcessedCapacity(size, null); + protected void releaseProcessedCapacity(int bytes, Message.Header header) + { + 
forId(header.id).controller.process(bytes, releaseCapacity); + } + }; + } + } + + static List endpoints(int count) + { + return IntStream.rangeClosed(1, count) + .mapToObj(ConnectionBurnTest::endpoint) + .collect(Collectors.toList()); + } + + private static InetAddressAndPort endpoint(int i) + { + try + { + return InetAddressAndPort.getByName("127.0.0." + i); + } + catch (UnknownHostException e) + { + throw new RuntimeException(e); + } + } + + public static void test(GlobalInboundSettings inbound, OutboundConnectionSettings outbound) throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException, TimeoutException + { + MessageGenerator small = new UniformPayloadGenerator(0, 1, (1 << 15)); + MessageGenerator large = new UniformPayloadGenerator(0, 1, (1 << 16) + (1 << 15)); + MessageGenerators generators = new MessageGenerators(small, large); + outbound = outbound.withApplicationSendQueueCapacityInBytes(1 << 18) + .withApplicationReserveSendQueueCapacityInBytes(1 << 30, new ResourceLimits.Concurrent(Integer.MAX_VALUE)); + + Test.builder() + .generators(generators) + .endpoints(4) + .inbound(inbound) + .outbound(outbound) + .time(2L, TimeUnit.DAYS) + .build().run(); + } + + public static void main(String[] args) throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException, TimeoutException + { + GlobalInboundSettings inboundSettings = new GlobalInboundSettings() + .withQueueCapacity(1 << 18) + .withEndpointReserveLimit(1 << 20) + .withGlobalReserveLimit(1 << 21) + .withTemplate(new InboundConnectionSettings() + .withEncryption(ConnectionTest.encryptionOptions)); + + test(inboundSettings, new OutboundConnectionSettings(null) + .withTcpUserTimeoutInMS(0)); + MessagingService.instance().socketFactory.shutdownNow(); + } + +} diff --git a/test/burn/org/apache/cassandra/net/GlobalInboundSettings.java b/test/burn/org/apache/cassandra/net/GlobalInboundSettings.java new file mode 100644 index 
000000000000..9b23041ae41c --- /dev/null +++ b/test/burn/org/apache/cassandra/net/GlobalInboundSettings.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +class GlobalInboundSettings +{ + final int queueCapacity; + final long endpointReserveLimit; + final long globalReserveLimit; + final InboundConnectionSettings template; + + GlobalInboundSettings() + { + this(0, 0, 0, null); + } + + GlobalInboundSettings(int queueCapacity, long endpointReserveLimit, long globalReserveLimit, InboundConnectionSettings template) + { + this.queueCapacity = queueCapacity; + this.endpointReserveLimit = endpointReserveLimit; + this.globalReserveLimit = globalReserveLimit; + this.template = template; + } + + GlobalInboundSettings withQueueCapacity(int queueCapacity) + { + return new GlobalInboundSettings(queueCapacity, endpointReserveLimit, globalReserveLimit, template); + } + GlobalInboundSettings withEndpointReserveLimit(int endpointReserveLimit) + { + return new GlobalInboundSettings(queueCapacity, endpointReserveLimit, globalReserveLimit, template); + } + GlobalInboundSettings withGlobalReserveLimit(int globalReserveLimit) + { + return new 
GlobalInboundSettings(queueCapacity, endpointReserveLimit, globalReserveLimit, template); + } + GlobalInboundSettings withTemplate(InboundConnectionSettings template) + { + return new GlobalInboundSettings(queueCapacity, endpointReserveLimit, globalReserveLimit, template); + } +} \ No newline at end of file diff --git a/test/burn/org/apache/cassandra/net/LogbackFilter.java b/test/burn/org/apache/cassandra/net/LogbackFilter.java new file mode 100644 index 000000000000..94aa2f9a6a71 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/LogbackFilter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.io.EOFException; +import java.nio.BufferOverflowException; +import java.util.Set; +import java.util.regex.Pattern; + +import com.google.common.collect.ImmutableSet; + +import ch.qos.logback.classic.spi.IThrowableProxy; +import ch.qos.logback.classic.spi.LoggingEvent; +import ch.qos.logback.core.filter.Filter; +import ch.qos.logback.core.spi.FilterReply; + +public class LogbackFilter extends Filter +{ + private static final Pattern ignore = Pattern.compile("(successfully connected|connection established), version ="); + + public FilterReply decide(Object o) + { + if (!(o instanceof LoggingEvent)) + return FilterReply.NEUTRAL; + + LoggingEvent e = (LoggingEvent) o; +// if (ignore.matcher(e.getMessage()).find()) +// return FilterReply.DENY; + + IThrowableProxy t = e.getThrowableProxy(); + if (t == null) + return FilterReply.NEUTRAL; + + if (!isIntentional(t)) + return FilterReply.NEUTRAL; + +// logger.info("Filtered exception {}: {}", t.getClassName(), t.getMessage()); + return FilterReply.DENY; + } + + private static final Set intentional = ImmutableSet.of( + Connection.IntentionalIOException.class.getName(), + Connection.IntentionalRuntimeException.class.getName(), + InvalidSerializedSizeException.class.getName(), + BufferOverflowException.class.getName(), + EOFException.class.getName() + ); + + public static boolean isIntentional(IThrowableProxy t) + { + while (true) + { + if (intentional.contains(t.getClassName())) + return true; + + if (null == t.getCause()) + return false; + + t = t.getCause(); + } + } + + +} diff --git a/test/burn/org/apache/cassandra/net/MessageGenerator.java b/test/burn/org/apache/cassandra/net/MessageGenerator.java new file mode 100644 index 000000000000..3c03689a566e --- /dev/null +++ b/test/burn/org/apache/cassandra/net/MessageGenerator.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.utils.vint.VIntCoding; +import sun.misc.Unsafe; + +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +abstract class MessageGenerator +{ + final long seed; + final Random random; + + private MessageGenerator(long seed) + { + this.seed = seed; + this.random = new Random(); + } + + Message.Builder builder(long id) + { + random.setSeed(id ^ seed); + long now = approxTime.now(); + + int expiresInMillis; + int expiryMask = random.nextInt(); + if (0 == (expiryMask & 0xffff)) expiresInMillis = 2; + else if (0 == (expiryMask & 0xfff)) expiresInMillis = 10; + else if (0 == (expiryMask & 0xff)) expiresInMillis = 100; + else if (0 == (expiryMask & 0xf)) expiresInMillis = 1000; + else expiresInMillis = 60 * 1000; + + long expiresInNanos = TimeUnit.MILLISECONDS.toNanos((expiresInMillis / 2) + random.nextInt(expiresInMillis / 2)); + + return 
Message.builder(Verb._TEST_2, null) + .withId(id) + .withCreatedAt(now) + .withExpiresAt(now + expiresInNanos); // don't expire for now + } + + public int uniformInt(int limit) + { + return random.nextInt(limit); + } + + // generate a Message with the provided id and with both id and info encoded in its payload + abstract Message generate(long id, byte info); + abstract MessageGenerator copy(); + + static final class UniformPayloadGenerator extends MessageGenerator + { + final int minSize; + final int maxSize; + final byte[] fillWithBytes; + UniformPayloadGenerator(long seed, int minSize, int maxSize) + { + super(seed); + this.minSize = Math.max(9, minSize); + this.maxSize = Math.max(9, maxSize); + this.fillWithBytes = new byte[32]; + random.setSeed(seed); + random.nextBytes(fillWithBytes); + } + + Message generate(long id, byte info) + { + Message.Builder builder = builder(id); + byte[] payload = new byte[minSize + random.nextInt(maxSize - minSize)]; + ByteBuffer wrapped = ByteBuffer.wrap(payload); + setId(payload, id); + payload[8] = info; + wrapped.position(9); + while (wrapped.hasRemaining()) + wrapped.put(fillWithBytes, 0, Math.min(fillWithBytes.length, wrapped.remaining())); + builder.withPayload(payload); + return builder.build(); + } + + MessageGenerator copy() + { + return new UniformPayloadGenerator(seed, minSize, maxSize); + } + } + + static long getId(byte[] payload) + { + return unsafe.getLong(payload, BYTE_ARRAY_BASE_OFFSET); + } + static byte getInfo(byte[] payload) + { + return payload[8]; + } + private static void setId(byte[] payload, long id) + { + unsafe.putLong(payload, BYTE_ARRAY_BASE_OFFSET, id); + } + + static class Header + { + public final int length; + public final long id; + public final byte info; + + Header(int length, long id, byte info) + { + this.length = length; + this.id = id; + this.info = info; + } + + public byte[] read(DataInputPlus in, int length, int messagingVersion) throws IOException + { + byte[] result = new 
byte[Math.max(9, length)]; + setId(result, id); + result[8] = info; + in.readFully(result, 9, Math.max(0, length - 9)); + return result; + } + } + + static Header readHeader(DataInputPlus in, int messagingVersion) throws IOException + { + int length = messagingVersion < VERSION_40 + ? in.readInt() + : (int) in.readUnsignedVInt(); + long id = in.readLong(); + if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) + id = Long.reverseBytes(id); + byte info = in.readByte(); + return new Header(length, id, info); + } + + static void writeLength(byte[] payload, DataOutputPlus out, int messagingVersion) throws IOException + { + if (messagingVersion < VERSION_40) + out.writeInt(payload.length); + else + out.writeUnsignedVInt(payload.length); + } + + static long serializedSize(byte[] payload, int messagingVersion) + { + return payload.length + (messagingVersion < VERSION_40 ? 4 : VIntCoding.computeUnsignedVIntSize(payload.length)); + } + + private static final Unsafe unsafe; + static + { + try + { + Field field = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + unsafe = (sun.misc.Unsafe) field.get(null); + } + catch (Exception e) + { + throw new AssertionError(e); + } + } + private static final long BYTE_ARRAY_BASE_OFFSET = unsafe.arrayBaseOffset(byte[].class); + +} + diff --git a/test/burn/org/apache/cassandra/net/MessageGenerators.java b/test/burn/org/apache/cassandra/net/MessageGenerators.java new file mode 100644 index 000000000000..92aab3a4a388 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/MessageGenerators.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +final class MessageGenerators +{ + final MessageGenerator small; + final MessageGenerator large; + + MessageGenerators(MessageGenerator small, MessageGenerator large) + { + this.small = small; + this.large = large; + } + + MessageGenerator get(ConnectionType type) + { + switch (type) + { + case SMALL_MESSAGES: + case URGENT_MESSAGES: + return small; + case LARGE_MESSAGES: + return large; + default: + throw new IllegalStateException(); + } + } +} diff --git a/test/burn/org/apache/cassandra/net/Reporters.java b/test/burn/org/apache/cassandra/net/Reporters.java new file mode 100644 index 000000000000..9ab46438a317 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/Reporters.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.LongFunction; +import java.util.function.ToLongFunction; + +import com.google.common.collect.ImmutableList; + +import org.apache.cassandra.locator.InetAddressAndPort; + +class Reporters +{ + final Collection endpoints; + final Connection[] connections; + final List reporters; + final long start = System.nanoTime(); + + Reporters(Collection endpoints, Connection[] connections) + { + this.endpoints = endpoints; + this.connections = connections; + this.reporters = ImmutableList.of( + outboundReporter (true, "Outbound Throughput", OutboundConnection::sentBytes, Reporters::prettyPrintMemory), + inboundReporter (true, "Inbound Throughput", InboundCounters::processedBytes, Reporters::prettyPrintMemory), + + outboundReporter (false, "Outbound Pending Bytes", OutboundConnection::pendingBytes, Reporters::prettyPrintMemory), + reporter (false, "Inbound Pending Bytes", c -> c.inbound.usingCapacity(), Reporters::prettyPrintMemory), + + outboundReporter (true, "Outbound Expirations", OutboundConnection::expiredCount, Long::toString), + inboundReporter (true, "Inbound Expirations", InboundCounters::expiredCount, Long::toString), + + outboundReporter (true, "Outbound Errors", OutboundConnection::errorCount, Long::toString), + inboundReporter (true, "Inbound Errors", InboundCounters::errorCount, Long::toString), + + outboundReporter (true, "Outbound Connection Attempts", OutboundConnection::connectionAttempts, Long::toString) + ); + } + + void update() + { + for (Reporter reporter : reporters) + reporter.update(); + } + + void print() + { + System.out.println("==" + prettyPrintElapsed(System.nanoTime() - start) + "==\n"); + + for (Reporter reporter : reporters) + { + reporter.print(); + } + } + + private Reporter outboundReporter(boolean accumulates, String name, 
ToLongFunction get, LongFunction printer) + { + return new Reporter(accumulates, name, (conn) -> get.applyAsLong(conn.outbound), printer); + } + + private Reporter inboundReporter(boolean accumulates, String name, ToLongFunction get, LongFunction printer) + { + return new Reporter(accumulates, name, (conn) -> get.applyAsLong(conn.inboundCounters()), printer); + } + + private Reporter reporter(boolean accumulates, String name, ToLongFunction get, LongFunction printer) + { + return new Reporter(accumulates, name, get, printer); + } + + class Reporter + { + boolean accumulates; + final String name; + final ToLongFunction get; + final LongFunction print; + final long[][] previousValue; + final long[] columnTotals = new long[1 + endpoints.size() * 3]; + final Table table; + + Reporter(boolean accumulates, String name, ToLongFunction get, LongFunction print) + { + this.accumulates = accumulates; + this.name = name; + this.get = get; + this.print = print; + + previousValue = accumulates ? new long[endpoints.size()][endpoints.size() * 3] : null; + + String[] rowNames = new String[endpoints.size() + 1]; + for (int row = 0 ; row < endpoints.size() ; ++row) + { + rowNames[row] = Integer.toString(1 + row); + } + rowNames[rowNames.length - 1] = "Total"; + + String[] columnNames = new String[endpoints.size() * 3 + 1]; + for (int column = 0 ; column < endpoints.size() * 3 ; column += 3) + { + String endpoint = Integer.toString(1 + column / 3); + columnNames[ column] = endpoint + ".Urgent"; + columnNames[1 + column] = endpoint + ".Small"; + columnNames[2 + column] = endpoint + ".Large"; + } + columnNames[columnNames.length - 1] = "Total"; + + table = new Table(rowNames, columnNames, "Recipient"); + } + + public void update() + { + Arrays.fill(columnTotals, 0); + int row = 0, connection = 0; + for (InetAddressAndPort recipient : endpoints) + { + int column = 0; + long rowTotal = 0; + for (InetAddressAndPort sender : endpoints) + { + for (ConnectionType type : 
ConnectionType.MESSAGING_TYPES) + { + assert recipient.equals(connections[connection].recipient); + assert sender.equals(connections[connection].sender); + assert type == connections[connection].outbound.type(); + + long cur = get.applyAsLong(connections[connection]); + long value; + if (accumulates) + { + long prev = previousValue[row][column]; + previousValue[row][column] = cur; + value = cur - prev; + } + else + { + value = cur; + } + table.set(row, column, print.apply(value)); + columnTotals[column] += value; + rowTotal += value; + ++column; + ++connection; + } + } + columnTotals[column] += rowTotal; + table.set(row, column, print.apply(rowTotal)); + table.displayRow(row, rowTotal > 0); + ++row; + } + + boolean displayTotalRow = false; + for (int column = 0 ; column < columnTotals.length ; ++column) + { + table.set(endpoints.size(), column, print.apply(columnTotals[column])); + table.displayColumn(column, columnTotals[column] > 0); + displayTotalRow |= columnTotals[column] > 0; + } + table.displayRow(endpoints.size(), displayTotalRow); + } + + public void print() + { + table.print("===" + name + "==="); + } + } + + private static class Table + { + final String[][] print; + final int[] width; + final BitSet rowMask = new BitSet(); + final BitSet columnMask = new BitSet(); + + public Table(String[] rowNames, String[] columnNames, String rowNameHeader) + { + print = new String[rowNames.length + 1][columnNames.length + 1]; + width = new int[columnNames.length + 1]; + print[0][0] = rowNameHeader; + for (int i = 0 ; i < columnNames.length ; ++i) + print[0][1 + i] = columnNames[i]; + for (int i = 0 ; i < rowNames.length ; ++i) + print[1 + i][0] = rowNames[i]; + } + + void set(int row, int column, String value) + { + print[row + 1][column + 1] = value; + } + + void displayRow(int row, boolean display) + { + rowMask.set(row, display); + } + + void displayColumn(int column, boolean display) + { + columnMask.set(column, display); + } + + void print(String heading) + { + 
if (rowMask.isEmpty() && columnMask.isEmpty()) + return; + + System.out.println(heading + '\n'); + + Arrays.fill(width, 0); + for (int row = 0 ; row < print.length ; ++row) + { + for (int column = 0 ; column < width.length ; ++column) + { + width[column] = Math.max(width[column], print[row][column].length()); + } + } + + for (int row = 0 ; row < print.length ; ++row) + { +// if (row > 0 && !rowMask.get(row - 1)) +// continue; + + StringBuilder builder = new StringBuilder(); + for (int column = 0 ; column < width.length ; ++column) + { +// if (column > 0 && !columnMask.get(column - 1)) +// continue; + + String s = print[row][column]; + int pad = width[column] - s.length(); + for (int i = 0 ; i < pad ; ++i) + builder.append(' '); + builder.append(s); + builder.append(" "); + } + System.out.println(builder.toString()); + } + System.out.println(); + } + } + + private static final class OneTimeUnit + { + final TimeUnit unit; + final String symbol; + final long nanos; + + private OneTimeUnit(TimeUnit unit, String symbol) + { + this.unit = unit; + this.symbol = symbol; + this.nanos = unit.toNanos(1L); + } + } + + private static final List prettyPrintElapsed = ImmutableList.of( + new OneTimeUnit(TimeUnit.DAYS, "d"), + new OneTimeUnit(TimeUnit.HOURS, "h"), + new OneTimeUnit(TimeUnit.MINUTES, "m"), + new OneTimeUnit(TimeUnit.SECONDS, "s"), + new OneTimeUnit(TimeUnit.MILLISECONDS, "ms"), + new OneTimeUnit(TimeUnit.MICROSECONDS, "us"), + new OneTimeUnit(TimeUnit.NANOSECONDS, "ns") + ); + + private static String prettyPrintElapsed(long nanos) + { + if (nanos == 0) + return "0ns"; + + int count = 0; + StringBuilder builder = new StringBuilder(); + for (OneTimeUnit unit : prettyPrintElapsed) + { + if (count == 2) + break; + + if (nanos >= unit.nanos) + { + if (count > 0) + builder.append(' '); + long inUnit = unit.unit.convert(nanos, TimeUnit.NANOSECONDS); + nanos -= unit.unit.toNanos(inUnit); + builder.append(inUnit); + builder.append(unit.symbol); + ++count; + } else if (count 
> 0) + ++count; + } + + return builder.toString(); + } + + static String prettyPrintMemory(long size) + { + if (size >= 1000 * 1000 * 1000) + return String.format("%.0fG", size / (double) (1 << 30)); + if (size >= 1000 * 1000) + return String.format("%.0fM", size / (double) (1 << 20)); + return String.format("%.0fK", size / (double) (1 << 10)); + } +} + diff --git a/test/burn/org/apache/cassandra/net/Verifier.java b/test/burn/org/apache/cassandra/net/Verifier.java new file mode 100644 index 000000000000..8b48c9a094f8 --- /dev/null +++ b/test/burn/org/apache/cassandra/net/Verifier.java @@ -0,0 +1,1637 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.nio.BufferOverflowException; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.locks.LockSupport; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.hppc.LongObjectOpenHashMap; +import org.apache.cassandra.net.Verifier.ExpiredMessageEvent.ExpirationType; +import org.apache.cassandra.utils.ApproximateTime; +import org.apache.cassandra.utils.concurrent.WaitQueue; + +import static java.util.concurrent.TimeUnit.*; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; +import static org.apache.cassandra.net.OutboundConnection.LargeMessageDelivery.DEFAULT_BUFFER_SIZE; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; +import static org.apache.cassandra.net.Verifier.EventCategory.OTHER; +import static org.apache.cassandra.net.Verifier.EventCategory.RECEIVE; +import static org.apache.cassandra.net.Verifier.EventCategory.SEND; +import static org.apache.cassandra.net.Verifier.EventType.ARRIVE; +import static org.apache.cassandra.net.Verifier.EventType.CLOSED_BEFORE_ARRIVAL; +import static org.apache.cassandra.net.Verifier.EventType.DESERIALIZE; +import static org.apache.cassandra.net.Verifier.EventType.ENQUEUE; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_CLOSING; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_DESERIALIZE; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_EXPIRED_ON_SEND; +import static 
org.apache.cassandra.net.Verifier.EventType.FAILED_EXPIRED_ON_RECEIVE; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_FRAME; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_OVERLOADED; +import static org.apache.cassandra.net.Verifier.EventType.FAILED_SERIALIZE; +import static org.apache.cassandra.net.Verifier.EventType.FINISH_SERIALIZE_LARGE; +import static org.apache.cassandra.net.Verifier.EventType.PROCESS; +import static org.apache.cassandra.net.Verifier.EventType.SEND_FRAME; +import static org.apache.cassandra.net.Verifier.EventType.SENT_FRAME; +import static org.apache.cassandra.net.Verifier.EventType.SERIALIZE; +import static org.apache.cassandra.net.Verifier.ExpiredMessageEvent.ExpirationType.ON_SENT; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +/** + * This class is a single-threaded verifier monitoring a single link, with events supplied by inbound and outbound threads + * + * By making verification single threaded, it is easier to reason about (and complex enough as is), but also permits + * a dedicated thread to monitor timeliness of events, e.g. 
elapsed time between a given SEND and its corresponding RECEIVE + * + * TODO: timeliness of events + * TODO: periodically stop all activity to/from a given endpoint, until it stops (and verify queues all empty, counters all accurate) + * TODO: integrate with proxy that corrupts frames + * TODO: test _OutboundConnection_ close + */ +@SuppressWarnings("WeakerAccess") +public class Verifier +{ + private static final Logger logger = LoggerFactory.getLogger(Verifier.class); + + public enum Destiny + { + SUCCEED, + FAIL_TO_SERIALIZE, + FAIL_TO_DESERIALIZE, + } + + enum EventCategory + { + SEND, RECEIVE, OTHER + } + + enum EventType + { + FAILED_OVERLOADED(SEND), + ENQUEUE(SEND), + FAILED_EXPIRED_ON_SEND(SEND), + SERIALIZE(SEND), + FAILED_SERIALIZE(SEND), + FINISH_SERIALIZE_LARGE(SEND), + SEND_FRAME(SEND), + FAILED_FRAME(SEND), + SENT_FRAME(SEND), + ARRIVE(RECEIVE), + FAILED_EXPIRED_ON_RECEIVE(RECEIVE), + DESERIALIZE(RECEIVE), + CLOSED_BEFORE_ARRIVAL(RECEIVE), + FAILED_DESERIALIZE(RECEIVE), + PROCESS(RECEIVE), + + FAILED_CLOSING(SEND), + + CONNECT_OUTBOUND(OTHER), + SYNC(OTHER), // the connection will stop sending messages, and promptly process any waiting inbound messages + CONTROLLER_UPDATE(OTHER); + + final EventCategory category; + + EventType(EventCategory category) + { + this.category = category; + } + } + + public static class Event + { + final EventType type; + Event(EventType type) + { + this.type = type; + } + } + + static class SimpleEvent extends Event + { + final long at; + SimpleEvent(EventType type, long at) + { + super(type); + this.at = at; + } + public String toString() { return type.toString(); } + } + + static class BoundedEvent extends Event + { + final long start; + volatile long end; + BoundedEvent(EventType type, long start) + { + super(type); + this.start = start; + } + public void complete(Verifier verifier) + { + end = verifier.sequenceId.getAndIncrement(); + verifier.events.put(end, this); + } + } + + static class SimpleMessageEvent extends 
SimpleEvent + { + final long messageId; + SimpleMessageEvent(EventType type, long at, long messageId) + { + super(type, at); + this.messageId = messageId; + } + } + + static class BoundedMessageEvent extends BoundedEvent + { + final long messageId; + BoundedMessageEvent(EventType type, long start, long messageId) + { + super(type, start); + this.messageId = messageId; + } + } + + static class EnqueueMessageEvent extends BoundedMessageEvent + { + final Message message; + final Destiny destiny; + EnqueueMessageEvent(EventType type, long start, Message message, Destiny destiny) + { + super(type, start, message.id()); + this.message = message; + this.destiny = destiny; + } + public String toString() { return String.format("%s{%s}", type, destiny); } + } + + static class SerializeMessageEvent extends SimpleMessageEvent + { + final int messagingVersion; + SerializeMessageEvent(EventType type, long at, long messageId, int messagingVersion) + { + super(type, at, messageId); + this.messagingVersion = messagingVersion; + } + public String toString() { return String.format("%s{ver=%d}", type, messagingVersion); } + } + + static class SimpleMessageEventWithSize extends SimpleMessageEvent + { + final int messageSize; + SimpleMessageEventWithSize(EventType type, long at, long messageId, int messageSize) + { + super(type, at, messageId); + this.messageSize = messageSize; + } + public String toString() { return String.format("%s{size=%d}", type, messageSize); } + } + + static class FailedSerializeEvent extends SimpleMessageEvent + { + final int bytesWrittenToNetwork; + final Throwable failure; + FailedSerializeEvent(long at, long messageId, int bytesWrittenToNetwork, Throwable failure) + { + super(FAILED_SERIALIZE, at, messageId); + this.bytesWrittenToNetwork = bytesWrittenToNetwork; + this.failure = failure; + } + public String toString() { return String.format("FAILED_SERIALIZE{written=%d, failure=%s}", bytesWrittenToNetwork, failure); } + } + + static class ExpiredMessageEvent 
extends SimpleMessageEvent + { + enum ExpirationType {ON_SENT, ON_ARRIVED, ON_PROCESSED } + final int messageSize; + final long timeElapsed; + final TimeUnit timeUnit; + final ExpirationType expirationType; + ExpiredMessageEvent(long at, long messageId, int messageSize, long timeElapsed, TimeUnit timeUnit, ExpirationType expirationType) + { + super(expirationType == ON_SENT ? FAILED_EXPIRED_ON_SEND : FAILED_EXPIRED_ON_RECEIVE, at, messageId); + this.messageSize = messageSize; + this.timeElapsed = timeElapsed; + this.timeUnit = timeUnit; + this.expirationType = expirationType; + } + public String toString() { return String.format("EXPIRED_%s{size=%d,elapsed=%d,unit=%s}", expirationType, messageSize, timeElapsed, timeUnit); } + } + + static class FrameEvent extends SimpleEvent + { + final int messageCount; + final int payloadSizeInBytes; + FrameEvent(EventType type, long at, int messageCount, int payloadSizeInBytes) + { + super(type, at); + this.messageCount = messageCount; + this.payloadSizeInBytes = payloadSizeInBytes; + } + } + + static class ProcessMessageEvent extends SimpleMessageEvent + { + final Message message; + ProcessMessageEvent(long at, Message message) + { + super(PROCESS, at, message.id()); + this.message = message; + } + } + + EnqueueMessageEvent onEnqueue(Message message, Destiny destiny) + { + EnqueueMessageEvent enqueue = new EnqueueMessageEvent(ENQUEUE, nextId(), message, destiny); + events.put(enqueue.start, enqueue); + return enqueue; + } + void onOverloaded(long messageId) + { + long at = nextId(); + events.put(at, new SimpleMessageEvent(FAILED_OVERLOADED, at, messageId)); + } + void onFailedClosing(long messageId) + { + long at = nextId(); + events.put(at, new SimpleMessageEvent(FAILED_CLOSING, at, messageId)); + } + void onSerialize(long messageId, int messagingVersion) + { + long at = nextId(); + events.put(at, new SerializeMessageEvent(SERIALIZE, at, messageId, messagingVersion)); + } + void onFinishSerializeLarge(long messageId) + { + 
long at = nextId(); + events.put(at, new SimpleMessageEvent(FINISH_SERIALIZE_LARGE, at, messageId)); + } + void onFailedSerialize(long messageId, int bytesWrittenToNetwork, Throwable failure) + { + long at = nextId(); + events.put(at, new FailedSerializeEvent(at, messageId, bytesWrittenToNetwork, failure)); + } + void onExpiredBeforeSend(long messageId, int messageSize, long timeElapsed, TimeUnit timeUnit) + { + onExpired(messageId, messageSize, timeElapsed, timeUnit, ON_SENT); + } + void onSendFrame(int messageCount, int payloadSizeInBytes) + { + long at = nextId(); + events.put(at, new FrameEvent(SEND_FRAME, at, messageCount, payloadSizeInBytes)); + } + void onSentFrame(int messageCount, int payloadSizeInBytes) + { + long at = nextId(); + events.put(at, new FrameEvent(SENT_FRAME, at, messageCount, payloadSizeInBytes)); + } + void onFailedFrame(int messageCount, int payloadSizeInBytes) + { + long at = nextId(); + events.put(at, new FrameEvent(FAILED_FRAME, at, messageCount, payloadSizeInBytes)); + } + void onArrived(long messageId, int messageSize) + { + long at = nextId(); + events.put(at, new SimpleMessageEventWithSize(ARRIVE, at, messageId, messageSize)); + } + void onArrivedExpired(long messageId, int messageSize, boolean wasCorrupt, long timeElapsed, TimeUnit timeUnit) + { + onExpired(messageId, messageSize, timeElapsed, timeUnit, ExpirationType.ON_ARRIVED); + } + void onDeserialize(long messageId, int messagingVersion) + { + long at = nextId(); + events.put(at, new SerializeMessageEvent(DESERIALIZE, at, messageId, messagingVersion)); + } + void onClosedBeforeArrival(long messageId, int messageSize) + { + long at = nextId(); + events.put(at, new SimpleMessageEventWithSize(CLOSED_BEFORE_ARRIVAL, at, messageId, messageSize)); + } + void onFailedDeserialize(long messageId, int messageSize) + { + long at = nextId(); + events.put(at, new SimpleMessageEventWithSize(FAILED_DESERIALIZE, at, messageId, messageSize)); + } + void process(Message message) + { + long at = 
nextId(); + events.put(at, new ProcessMessageEvent(at, message)); + } + void onProcessExpired(long messageId, int messageSize, long timeElapsed, TimeUnit timeUnit) + { + onExpired(messageId, messageSize, timeElapsed, timeUnit, ExpirationType.ON_PROCESSED); + } + private void onExpired(long messageId, int messageSize, long timeElapsed, TimeUnit timeUnit, ExpirationType expirationType) + { + long at = nextId(); + events.put(at, new ExpiredMessageEvent(at, messageId, messageSize, timeElapsed, timeUnit, expirationType)); + } + + + + static class ConnectOutboundEvent extends SimpleEvent + { + final int messagingVersion; + final OutboundConnectionSettings settings; + ConnectOutboundEvent(long at, int messagingVersion, OutboundConnectionSettings settings) + { + super(EventType.CONNECT_OUTBOUND, at); + this.messagingVersion = messagingVersion; + this.settings = settings; + } + } + + // TODO: do we need this? + static class ConnectInboundEvent extends SimpleEvent + { + final int messagingVersion; + final InboundMessageHandler handler; + ConnectInboundEvent(long at, int messagingVersion, InboundMessageHandler handler) + { + super(EventType.CONNECT_OUTBOUND, at); + this.messagingVersion = messagingVersion; + this.handler = handler; + } + } + + static class SyncEvent extends SimpleEvent + { + final Runnable onCompletion; + SyncEvent(long at, Runnable onCompletion) + { + super(EventType.SYNC, at); + this.onCompletion = onCompletion; + } + } + + static class ControllerEvent extends BoundedEvent + { + final long minimumBytesInFlight; + final long maximumBytesInFlight; + ControllerEvent(long start, long minimumBytesInFlight, long maximumBytesInFlight) + { + super(EventType.CONTROLLER_UPDATE, start); + this.minimumBytesInFlight = minimumBytesInFlight; + this.maximumBytesInFlight = maximumBytesInFlight; + } + } + + void onSync(Runnable onCompletion) + { + SyncEvent connect = new SyncEvent(nextId(), onCompletion); + events.put(connect.at, connect); + } + + void onConnectOutbound(int 
messagingVersion, OutboundConnectionSettings settings) + { + ConnectOutboundEvent connect = new ConnectOutboundEvent(nextId(), messagingVersion, settings); + events.put(connect.at, connect); + } + + void onConnectInbound(int messagingVersion, InboundMessageHandler handler) + { + ConnectInboundEvent connect = new ConnectInboundEvent(nextId(), messagingVersion, handler); + events.put(connect.at, connect); + } + + private final BytesInFlightController controller; + private final AtomicLong sequenceId = new AtomicLong(); + private final EventSequence events = new EventSequence(); + private final InboundMessageHandlers inbound; + private final OutboundConnection outbound; + + Verifier(BytesInFlightController controller, OutboundConnection outbound, InboundMessageHandlers inbound) + { + this.controller = controller; + this.inbound = inbound; + this.outbound = outbound; + } + + private long nextId() + { + return sequenceId.getAndIncrement(); + } + + public void logFailure(String message, Object ... params) + { + fail(message, params); + } + + private void fail(String message, Object ... params) + { + logger.error("{}", String.format(message, params)); + logger.error("Connection: {}", currentConnection); + } + + private void fail(String message, Throwable t, Object ... params) + { + logger.error("{}", String.format(message, params), t); + logger.error("Connection: {}", currentConnection); + } + + private void failinfo(String message, Object ... 
params) + { + logger.error("{}", String.format(message, params)); + } + + private static class MessageState + { + final Message message; + final Destiny destiny; + int messagingVersion; + // set initially to message.expiresAtNanos, but if at serialization time we use + // an older messaging version we may not be able to serialize expiration + long expiresAtNanos; + long enqueueStart, enqueueEnd, serialize, arrive, deserialize; + boolean processOnEventLoop, processOutOfOrder; + Event sendState, receiveState; + long lastUpdateAt; + long lastUpdateNanos; + ConnectionState sentOn; + boolean doneSend, doneReceive; + + int messageSize() + { + return message.serializedSize(messagingVersion); + } + + MessageState(Message message, Destiny destiny, long enqueueStart) + { + this.message = message; + this.destiny = destiny; + this.enqueueStart = enqueueStart; + this.expiresAtNanos = message.expiresAtNanos(); + } + + void update(SimpleEvent event, long now) + { + update(event, event.at, now); + } + void update(Event event, long at, long now) + { + lastUpdateAt = at; + lastUpdateNanos = now; + switch (event.type.category) + { + case SEND: + sendState = event; + break; + case RECEIVE: + receiveState = event; + break; + default: throw new IllegalStateException(); + } + } + + boolean is(EventType type) + { + switch (type.category) + { + case SEND: return sendState != null && sendState.type == type; + case RECEIVE: return receiveState != null && receiveState.type == type; + default: return false; + } + } + + boolean is(EventType type1, EventType type2) + { + return is(type1) || is(type2); + } + + boolean is(EventType type1, EventType type2, EventType type3) + { + return is(type1) || is(type2) || is(type3); + } + + void require(EventType event, Verifier verifier, EventType type) + { + if (!is(type)) + verifier.fail("Invalid state at %s for %s: expected %s", event, this, type); + } + + void require(EventType event, Verifier verifier, EventType type1, EventType type2) + { + if 
(!is(type1) && !is(type2)) + verifier.fail("Invalid state at %s for %s: expected %s or %s", event, this, type1, type2); + } + + void require(EventType event, Verifier verifier, EventType type1, EventType type2, EventType type3) + { + if (!is(type1) && !is(type2) && !is(type3)) + verifier.fail("Invalid state %s for %s: expected %s, %s or %s", event, this, type1, type2, type3); + } + + public String toString() + { + return String.format("{id:%d, state:[%s,%s], upd:%d, ver:%d, enqueue:[%d,%d], ser:%d, arr:%d, deser:%d, expires:%d, sentOn: %d}", + message.id(), sendState, receiveState, lastUpdateAt, messagingVersion, enqueueStart, enqueueEnd, serialize, arrive, deserialize, approxTime.translate().toMillisSinceEpoch(expiresAtNanos), sentOn == null ? -1 : sentOn.connectionId); + } + } + + private final LongObjectOpenHashMap messages = new LongObjectOpenHashMap<>(); + + // messages start here, but may enter in a haphazard (non-sequential) fashion; + // ENQUEUE_START, ENQUEUE_END both take place here, with the latter imposing bounds on the out-of-order appearance of messages. + // note that ENQUEUE_END - being concurrent - may not appear before the message's lifespan has completely ended. + private final Queue enqueueing = new Queue<>(); + + static final class ConnectionState + { + final long connectionId; + final int messagingVersion; + // Strict message order will then be determined at serialization time, since this happens on a single thread. + // The order in which messages arrive here determines the order they will arrive on the other node. 
+ // must follow either ENQUEUE_START or ENQUEUE_END + final Queue serializing = new Queue<>(); + + // Messages sent on the small connection will all be sent in frames; this is a concurrent operation, + // so only the sendingFrame MUST be encountered before any future events - + // large connections skip this step and goes straight to arriving + // we consult the queues in reverse order in arriving, as it is acceptable to find our frame in any of these queues + final FramesInFlight framesInFlight = new FramesInFlight(); // unknown if the messages will arrive, accept either + + // for large messages OR < VERSION_40, arriving can occur BEFORE serializing completes successfully + // OR a frame is fully serialized + final Queue arriving = new Queue<>(); + + final Queue deserializingOnEventLoop = new Queue<>(), + deserializingOffEventLoop = new Queue<>(); + + InboundMessageHandler inbound; + + // TODO + long sentCount, sentBytes; + long receivedCount, receivedBytes; + + ConnectionState(long connectionId, int messagingVersion) + { + this.connectionId = connectionId; + this.messagingVersion = messagingVersion; + } + + public String toString() + { + return String.format("{id: %d, ver: %d, ser: %d, inFlight: %s, arriving: %d, deserOn: %d, deserOff: %d}", + connectionId, messagingVersion, serializing.size(), framesInFlight, arriving.size(), deserializingOnEventLoop.size(), deserializingOffEventLoop.size()); + } + } + + private final Queue processingOutOfOrder = new Queue<>(); + + private SyncEvent sync; + private long nextMessageId = 0; + private long now; + private long connectionCounter; + private ConnectionState currentConnection = new ConnectionState(connectionCounter++, current_version); + + private long outboundSentCount, outboundSentBytes; + private long outboundSubmittedCount; + private long outboundOverloadedCount, outboundOverloadedBytes; + private long outboundExpiredCount, outboundExpiredBytes; + private long outboundErrorCount, outboundErrorBytes; + + public 
void run(Runnable onFailure, long deadlineNanos) + { + try + { + long lastEventAt = approxTime.now(); + while ((now = approxTime.now()) < deadlineNanos) + { + Event next = events.await(nextMessageId, 100L, MILLISECONDS); + if (next == null) + { + // decide if we have any messages waiting too long to proceed + while (!processingOutOfOrder.isEmpty()) + { + MessageState m = processingOutOfOrder.get(0); + if (now - m.lastUpdateNanos > SECONDS.toNanos(10L)) + { + fail("Unreasonably long period spent waiting for out-of-order deser/delivery of received message %d", m.message.id()); + MessageState v = maybeRemove(m.message.id(), PROCESS); + controller.fail(v.message.serializedSize(v.messagingVersion == 0 ? current_version : v.messagingVersion)); + processingOutOfOrder.remove(0); + } + else break; + } + + if (sync != null) + { + // if we have waited 100ms since beginning a sync, with no events, and ANY of our queues are + // non-empty, something is probably wrong; however, let's give ourselves a little bit longer + + boolean done = + currentConnection.serializing.isEmpty() + && currentConnection.arriving.isEmpty() + && currentConnection.deserializingOnEventLoop.isEmpty() + && currentConnection.deserializingOffEventLoop.isEmpty() + && currentConnection.framesInFlight.isEmpty() + && enqueueing.isEmpty() + && processingOutOfOrder.isEmpty() + && messages.isEmpty() + && controller.inFlight() == 0; + + //outbound.pendingCount() > 0 ? 5L : 2L + if (!done && now - lastEventAt > SECONDS.toNanos(5L)) + { + // TODO: even 2s or 5s are unreasonable periods of time without _any_ movement on a message waiting to arrive + // this seems to happen regularly on MacOS, but we should confirm this does not happen on Linux + fail("Unreasonably long period spent waiting for sync (%dms)", NANOSECONDS.toMillis(now - lastEventAt)); + messages.forEach((k, v) -> { + failinfo("%s", v); + controller.fail(v.message.serializedSize(v.messagingVersion == 0 ? 
current_version : v.messagingVersion)); + }); + currentConnection.serializing.clear(); + currentConnection.arriving.clear(); + currentConnection.deserializingOnEventLoop.clear(); + currentConnection.deserializingOffEventLoop.clear(); + enqueueing.clear(); + processingOutOfOrder.clear(); + messages.clear(); + while (!currentConnection.framesInFlight.isEmpty()) + currentConnection.framesInFlight.poll(); + done = true; + } + + if (done) + { + ConnectionUtils.check(outbound) + .pending(0, 0) + .error(outboundErrorCount, outboundErrorBytes) + .submitted(outboundSubmittedCount) + .expired(outboundExpiredCount, outboundExpiredBytes) + .overload(outboundOverloadedCount, outboundOverloadedBytes) + .sent(outboundSentCount, outboundSentBytes) + .check((message, expect, actual) -> fail("%s: expect %d, actual %d", message, expect, actual)); + + sync.onCompletion.run(); + sync = null; + } + } + continue; + } + events.clear(nextMessageId); // TODO: simplify collection if we end up using it exclusively as a queue, as we are now + lastEventAt = now; + + switch (next.type) + { + case ENQUEUE: + { + MessageState m; + EnqueueMessageEvent e = (EnqueueMessageEvent) next; + assert nextMessageId == e.start || nextMessageId == e.end; + assert e.message != null; + if (nextMessageId == e.start) + { + if (sync != null) + fail("Sync in progress - there should be no messages beginning to enqueue"); + + m = new MessageState(e.message, e.destiny, e.start); + messages.put(e.messageId, m); + enqueueing.add(m); + m.update(e, e.start, now); + } + else + { + // warning: enqueueEnd can occur at any time in the future, since it's a different thread; + // it could be arbitrarily paused, long enough even for the messsage to be fully processed + m = messages.get(e.messageId); + if (m != null) + m.enqueueEnd = e.end; + outboundSubmittedCount += 1; + } + break; + } + case FAILED_OVERLOADED: + { + // TODO: verify that we could have exceeded our memory limits + SimpleMessageEvent e = (SimpleMessageEvent) next; 
+ assert nextMessageId == e.at; + MessageState m = remove(e.messageId, enqueueing, messages); + m.require(FAILED_OVERLOADED, this, ENQUEUE); + outboundOverloadedBytes += m.message.serializedSize(current_version); + outboundOverloadedCount += 1; + break; + } + case FAILED_CLOSING: + { + // TODO: verify if this is acceptable due to e.g. inbound refusing to process for long enough + SimpleMessageEvent e = (SimpleMessageEvent) next; + assert nextMessageId == e.at; + MessageState m = messages.remove(e.messageId); // definitely cannot have been sent (in theory) + enqueueing.remove(m); + m.require(FAILED_CLOSING, this, ENQUEUE); + fail("Invalid discard of %d: connection was closing for too long", m.message.id()); + break; + } + case SERIALIZE: + { + // serialize happens serially, so we can compress the asynchronicity of the above enqueue + // into a linear sequence of events we expect to occur on arrival + SerializeMessageEvent e = (SerializeMessageEvent) next; + assert nextMessageId == e.at; + MessageState m = get(e); + assert m.is(ENQUEUE); + m.serialize = e.at; + m.messagingVersion = e.messagingVersion; + if (current_version != e.messagingVersion) + controller.adjust(m.message.serializedSize(current_version), m.message.serializedSize(e.messagingVersion)); + + m.processOnEventLoop = willProcessOnEventLoop(outbound.type(), m.message, e.messagingVersion); + m.expiresAtNanos = expiresAtNanos(m.message, e.messagingVersion); + int mi = enqueueing.indexOf(m); + for (int i = 0 ; i < mi ; ++i) + { + MessageState pm = enqueueing.get(i); + if (pm.enqueueEnd != 0 && pm.enqueueEnd < m.enqueueStart) + { + fail("Invalid order of events: %s enqueued strictly before %s, but serialized after", + pm, m); + } + } + enqueueing.remove(mi); + m.sentOn = currentConnection; + currentConnection.serializing.add(m); + m.update(e, now); + break; + } + case FINISH_SERIALIZE_LARGE: + { + // serialize happens serially, so we can compress the asynchronicity of the above enqueue + // into a linear 
sequence of events we expect to occur on arrival + SimpleMessageEvent e = (SimpleMessageEvent) next; + assert nextMessageId == e.at; + MessageState m = maybeRemove(e); + outboundSentBytes += m.messageSize(); + outboundSentCount += 1; + m.sentOn.serializing.remove(m); + m.update(e, now); + break; + } + case FAILED_SERIALIZE: + { + FailedSerializeEvent e = (FailedSerializeEvent) next; + assert nextMessageId == e.at; + MessageState m = maybeRemove(e); + + if (outbound.type() == LARGE_MESSAGES) + assert e.failure instanceof InvalidSerializedSizeException || e.failure instanceof Connection.IntentionalIOException || e.failure instanceof Connection.IntentionalRuntimeException; + else + assert e.failure instanceof InvalidSerializedSizeException || e.failure instanceof Connection.IntentionalIOException || e.failure instanceof Connection.IntentionalRuntimeException || e.failure instanceof BufferOverflowException; + + if (e.bytesWrittenToNetwork == 0) // TODO: use header size + messages.remove(m.message.id()); + + InvalidSerializedSizeException ex; + if (outbound.type() != LARGE_MESSAGES + || !(e.failure instanceof InvalidSerializedSizeException) + || ((ex = (InvalidSerializedSizeException) e.failure).expectedSize <= DEFAULT_BUFFER_SIZE && ex.actualSizeAtLeast <= DEFAULT_BUFFER_SIZE) + || (ex.expectedSize > DEFAULT_BUFFER_SIZE && ex.actualSizeAtLeast < DEFAULT_BUFFER_SIZE)) + { + assert e.bytesWrittenToNetwork == 0; + } + + m.require(FAILED_SERIALIZE, this, SERIALIZE); + m.sentOn.serializing.remove(m); + if (m.destiny != Destiny.FAIL_TO_SERIALIZE) + fail("%s failed to serialize, but its destiny was to %s", m, m.destiny); + outboundErrorBytes += m.messageSize(); + outboundErrorCount += 1; + m.update(e, now); + break; + } + case SEND_FRAME: + { + FrameEvent e = (FrameEvent) next; + assert nextMessageId == e.at; + int size = 0; + Frame frame = new Frame(); + MessageState first = currentConnection.serializing.get(0); + int messagingVersion = first.messagingVersion; + for (int i = 
0 ; i < e.messageCount ; ++i) + { + MessageState m = currentConnection.serializing.get(i); + size += m.message.serializedSize(m.messagingVersion); + if (m.messagingVersion != messagingVersion) + { + fail("Invalid sequence of events: %s encoded to same frame as %s", + m, first); + } + + frame.add(m); + m.update(e, now); + assert !m.doneSend; + m.doneSend = true; + if (m.doneReceive) + messages.remove(m.message.id()); + } + frame.payloadSizeInBytes = e.payloadSizeInBytes; + frame.messageCount = e.messageCount; + frame.messagingVersion = messagingVersion; + currentConnection.framesInFlight.add(frame); + currentConnection.serializing.removeFirst(e.messageCount); + if (e.payloadSizeInBytes != size) + fail("Invalid frame payload size with %s: expected %d, actual %d", first, size, e.payloadSizeInBytes); + break; + } + case SENT_FRAME: + { + Frame frame = currentConnection.framesInFlight.supplySendStatus(Frame.Status.SUCCESS); + frame.forEach(m -> m.update((SimpleEvent) next, now)); + + outboundSentBytes += frame.payloadSizeInBytes; + outboundSentCount += frame.messageCount; + break; + } + case FAILED_FRAME: + { + // TODO: is it possible for this to be signalled AFTER our reconnect event? 
probably, in which case this will be wrong + // TODO: verify that this was expected + Frame frame = currentConnection.framesInFlight.supplySendStatus(Frame.Status.FAILED); + frame.forEach(m -> m.update((SimpleEvent) next, now)); + if (frame.messagingVersion >= VERSION_40) + { + // the contents cannot be delivered without the whole frame arriving, so clear the contents now + clear(frame, messages); + currentConnection.framesInFlight.remove(frame); + } + outboundErrorBytes += frame.payloadSizeInBytes; + outboundErrorCount += frame.messageCount; + break; + } + case ARRIVE: + { + SimpleMessageEventWithSize e = (SimpleMessageEventWithSize) next; + assert nextMessageId == e.at; + MessageState m = get(e); + + m.arrive = e.at; + if (e.messageSize != m.messageSize()) + fail("onArrived with invalid size for %s: %d vs %d", m, e.messageSize, m.messageSize()); + + if (outbound.type() == LARGE_MESSAGES) + { + m.require(ARRIVE, this, SERIALIZE, FAILED_SERIALIZE, FINISH_SERIALIZE_LARGE); + } + else + { + if (!m.is(SEND_FRAME, SENT_FRAME)) + { + fail("Invalid order of events: %s arrived before being sent in a frame", m); + break; + } + + int fi = -1, mi = -1; + while (fi + 1 < m.sentOn.framesInFlight.size() && mi < 0) + mi = m.sentOn.framesInFlight.get(++fi).indexOf(m); + + if (fi == m.sentOn.framesInFlight.size()) + { + fail("Invalid state: %s, but no frame in flight was found to contain it", m); + break; + } + + if (fi > 0) + { + // we have skipped over some frames, meaning these have either failed (and we know it) + // or we have not yet heard about them and they have presumably failed, or something + // has gone wrong + fail("BEGIN: Successfully sent frames were not delivered"); + for (int i = 0 ; i < fi ; ++i) + { + Frame skip = m.sentOn.framesInFlight.get(i); + skip.receiveStatus = Frame.Status.FAILED; + if (skip.sendStatus == Frame.Status.SUCCESS) + { + failinfo("Frame %s", skip); + for (int j = 0 ; j < skip.size() ; ++j) + failinfo("Containing: %s", skip.get(j)); + } + 
clear(skip, messages); + } + m.sentOn.framesInFlight.removeFirst(fi); + failinfo("END: Successfully sent frames were not delivered"); + } + + Frame frame = m.sentOn.framesInFlight.get(0); + for (int i = 0; i < mi; ++i) + fail("Invalid order of events: %s serialized strictly before %s, but arrived after", frame.get(i), m); + + frame.remove(mi); + if (frame.isEmpty()) + m.sentOn.framesInFlight.poll(); + } + m.sentOn.arriving.add(m); + m.update(e, now); + break; + } + case DESERIALIZE: + { + // deserialize may happen in parallel for large messages, but in sequence for small messages + // we currently require that this event be issued before any possible error is thrown + SimpleMessageEvent e = (SimpleMessageEvent) next; + assert nextMessageId == e.at; + MessageState m = get(e); + m.require(DESERIALIZE, this, ARRIVE); + m.deserialize = e.at; + // deserialize may be off-loaded, so we can only impose meaningful ordering constraints + // on those messages we know to have been processed on the event loop + int mi = m.sentOn.arriving.indexOf(m); + if (m.processOnEventLoop) + { + for (int i = 0 ; i < mi ; ++i) + { + MessageState pm = m.sentOn.arriving.get(i); + if (pm.processOnEventLoop) + { + fail("Invalid order of events: %d (%d, %d) arrived strictly before %d (%d, %d), but deserialized after", + pm.message.id(), pm.arrive, pm.deserialize, m.message.id(), m.arrive, m.deserialize); + } + } + m.sentOn.deserializingOnEventLoop.add(m); + } + else + { + m.sentOn.deserializingOffEventLoop.add(m); + } + m.sentOn.arriving.remove(mi); + m.update(e, now); + break; + } + case CLOSED_BEFORE_ARRIVAL: + { + SimpleMessageEventWithSize e = (SimpleMessageEventWithSize) next; + assert nextMessageId == e.at; + MessageState m = maybeRemove(e); + + if (e.messageSize != m.messageSize()) + fail("onClosedBeforeArrival has invalid size for %s: %d vs %d", m, e.messageSize, m.messageSize()); + + m.sentOn.deserializingOffEventLoop.remove(m); + if (m.destiny == Destiny.FAIL_TO_SERIALIZE && 
outbound.type() == LARGE_MESSAGES) + break; + fail("%s closed before arrival, but its destiny was to %s", m, m.destiny); + break; + } + case FAILED_DESERIALIZE: + { + SimpleMessageEventWithSize e = (SimpleMessageEventWithSize) next; + assert nextMessageId == e.at; + MessageState m = maybeRemove(e); + + if (e.messageSize != m.messageSize()) + fail("onFailedDeserialize has invalid size for %s: %d vs %d", m, e.messageSize, m.messageSize()); + m.require(FAILED_DESERIALIZE, this, ARRIVE, DESERIALIZE); + (m.processOnEventLoop ? m.sentOn.deserializingOnEventLoop : m.sentOn.deserializingOffEventLoop).remove(m); + switch (m.destiny) + { + case FAIL_TO_DESERIALIZE: + break; + case FAIL_TO_SERIALIZE: + if (outbound.type() == LARGE_MESSAGES) + break; + default: + fail("%s failed to deserialize, but its destiny was to %s", m, m.destiny); + } + break; + } + case PROCESS: + { + ProcessMessageEvent e = (ProcessMessageEvent) next; + assert nextMessageId == e.at; + MessageState m = maybeRemove(e); + + m.require(PROCESS, this, DESERIALIZE); + if (!Arrays.equals((byte[]) e.message.payload, (byte[]) m.message.payload)) + { + fail("Invalid message payload for %d: %s supplied by processor, but %s implied by original message and messaging version", + e.messageId, Arrays.toString((byte[]) e.message.payload), Arrays.toString((byte[]) m.message.payload)); + } + + if (m.processOutOfOrder) + { + assert !m.processOnEventLoop; // will have already been reported small (processOnEventLoop) messages + processingOutOfOrder.remove(m); + } + else if (m.processOnEventLoop) + { + // we can expect that processing happens sequentially in this case, more specifically + // we can actually expect that this event will occur _immediately_ after the deserialize event + // so that we have exactly one mess + // c + int mi = m.sentOn.deserializingOnEventLoop.indexOf(m); + for (int i = 0 ; i < mi ; ++i) + { + MessageState pm = m.sentOn.deserializingOnEventLoop.get(i); + fail("Invalid order of events: %s 
deserialized strictly before %s, but processed after", + pm, m); + } + clearFirst(mi, m.sentOn.deserializingOnEventLoop, messages); + m.sentOn.deserializingOnEventLoop.poll(); + } + else + { + int mi = m.sentOn.deserializingOffEventLoop.indexOf(m); + // process may be off-loaded, so we can only impose meaningful ordering constraints + // on those messages we know to have been processed on the event loop + for (int i = 0 ; i < mi ; ++i) + { + MessageState pm = m.sentOn.deserializingOffEventLoop.get(i); + pm.processOutOfOrder = true; + processingOutOfOrder.add(pm); + } + m.sentOn.deserializingOffEventLoop.removeFirst(mi + 1); + } + // this message has been fully validated + break; + } + case FAILED_EXPIRED_ON_SEND: + case FAILED_EXPIRED_ON_RECEIVE: + { + ExpiredMessageEvent e = (ExpiredMessageEvent) next; + assert nextMessageId == e.at; + MessageState m; + switch (e.expirationType) + { + case ON_SENT: + { + m = messages.remove(e.messageId); + m.require(e.type, this, ENQUEUE); + outboundExpiredBytes += m.message.serializedSize(current_version); + outboundExpiredCount += 1; + messages.remove(m.message.id()); + break; + } + case ON_ARRIVED: + m = maybeRemove(e); + if (!m.is(ARRIVE)) + { + if (outbound.type() != LARGE_MESSAGES) m.require(e.type, this, SEND_FRAME, SENT_FRAME, FAILED_FRAME); + else m.require(e.type, this, SERIALIZE, FAILED_SERIALIZE, FINISH_SERIALIZE_LARGE); + } + break; + case ON_PROCESSED: + m = maybeRemove(e); + m.require(e.type, this, DESERIALIZE); + break; + default: + throw new IllegalStateException(); + } + + now = System.nanoTime(); + if (m.expiresAtNanos > now) + { + // we fix the conversion AlmostSameTime for an entire run, which should suffice to guarantee these comparisons + fail("Invalid expiry of %d: expiry should occur in %dms; event believes %dms have elapsed, and %dms have actually elapsed", m.message.id(), + NANOSECONDS.toMillis(m.expiresAtNanos - m.message.createdAtNanos()), + e.timeUnit.toMillis(e.timeElapsed), + 
NANOSECONDS.toMillis(now - m.message.createdAtNanos())); + } + + switch (e.expirationType) + { + case ON_SENT: + enqueueing.remove(m); + break; + case ON_ARRIVED: + if (m.is(ARRIVE)) + m.sentOn.arriving.remove(m); + switch (m.sendState.type) + { + case SEND_FRAME: + case SENT_FRAME: + case FAILED_FRAME: + // TODO: this should be robust to re-ordering; should perhaps extract a common method + m.sentOn.framesInFlight.get(0).remove(m); + if (m.sentOn.framesInFlight.get(0).isEmpty()) + m.sentOn.framesInFlight.poll(); + break; + } + break; + case ON_PROCESSED: + (m.processOnEventLoop ? m.sentOn.deserializingOnEventLoop : m.sentOn.deserializingOffEventLoop).remove(m); + break; + } + + if (m.messagingVersion != 0 && e.messageSize != m.messageSize()) + fail("onExpired %s with invalid size for %s: %d vs %d", e.expirationType, m, e.messageSize, m.messageSize()); + + break; + } + case CONTROLLER_UPDATE: + { + break; + } + case CONNECT_OUTBOUND: + { + ConnectOutboundEvent e = (ConnectOutboundEvent) next; + currentConnection = new ConnectionState(connectionCounter++, e.messagingVersion); + break; + } + case SYNC: + { + sync = (SyncEvent) next; + break; + } + default: + throw new IllegalStateException(); + } + ++nextMessageId; + } + } + catch (InterruptedException e) + { + } + catch (Throwable t) + { + logger.error("Unexpected error:", t); + onFailure.run(); + } + } + + private MessageState get(SimpleMessageEvent onEvent) + { + MessageState m = messages.get(onEvent.messageId); + if (m == null) + throw new IllegalStateException("Missing " + onEvent + ": " + onEvent.messageId); + return m; + } + private MessageState maybeRemove(SimpleMessageEvent onEvent) + { + return maybeRemove(onEvent.messageId, onEvent.type, onEvent); + } + private MessageState maybeRemove(long messageId, EventType onEvent) + { + return maybeRemove(messageId, onEvent, onEvent); + } + private MessageState maybeRemove(long messageId, EventType onEvent, Object id) + { + MessageState m = messages.get(messageId); + 
if (m == null) + throw new IllegalStateException("Missing " + id + ": " + messageId); + switch (onEvent.category) + { + case SEND: + if (m.doneSend) + fail("%s already doneSend %s", onEvent, m); + m.doneSend = true; + if (m.doneReceive) messages.remove(messageId); + break; + case RECEIVE: + if (m.doneReceive) + fail("%s already doneReceive %s", onEvent, m); + m.doneReceive = true; + if (m.doneSend) messages.remove(messageId); + } + return m; + } + + + private static class Frame extends Queue + { + enum Status { SUCCESS, FAILED, UNKNOWN } + Status sendStatus = Status.UNKNOWN, receiveStatus = Status.UNKNOWN; + int messagingVersion; + int messageCount; + int payloadSizeInBytes; + + public String toString() + { + return String.format("{count:%d, size:%d, version:%d, send:%s, receive:%s}", + messageCount, payloadSizeInBytes, messagingVersion, sendStatus, receiveStatus); + } + } + + private static MessageState remove(long messageId, Queue queue, LongObjectOpenHashMap lookup) + { + MessageState m = lookup.remove(messageId); + queue.remove(m); + return m; + } + + private static void clearFirst(int count, Queue queue, LongObjectOpenHashMap lookup) + { + if (count > 0) + { + for (int i = 0 ; i < count ; ++i) + lookup.remove(queue.get(i).message.id()); + queue.removeFirst(count); + } + } + + private static void clear(Queue queue, LongObjectOpenHashMap lookup) + { + if (!queue.isEmpty()) + clearFirst(queue.size(), queue, lookup); + } + + private static class EventSequence + { + static final int CHUNK_SIZE = 1 << 10; + static class Chunk extends AtomicReferenceArray + { + final long sequenceId; + int removed = 0; + Chunk(long sequenceId) + { + super(CHUNK_SIZE); + this.sequenceId = sequenceId; + } + Event get(long sequenceId) + { + return get((int)(sequenceId - this.sequenceId)); + } + void set(long sequenceId, Event event) + { + lazySet((int)(sequenceId - this.sequenceId), event); + } + } + + // we use a concurrent skip list to permit efficient searching, even if we always 
append + final ConcurrentSkipListMap chunkList = new ConcurrentSkipListMap<>(); + final WaitQueue writerWaiting = new WaitQueue(); + + volatile Chunk writerChunk = new Chunk(0); + Chunk readerChunk = writerChunk; + + long readerWaitingFor; + volatile Thread readerWaiting; + + EventSequence() + { + chunkList.put(0L, writerChunk); + } + + public void put(long sequenceId, Event event) + { + long chunkSequenceId = sequenceId & -CHUNK_SIZE; + Chunk chunk = writerChunk; + if (chunk.sequenceId != chunkSequenceId) + { + try + { + chunk = ensureChunk(chunkSequenceId); + } + catch (InterruptedException e) + { + throw new RuntimeException(e); + } + } + + chunk.set(sequenceId, event); + + Thread wake = readerWaiting; + long wakeIf = readerWaitingFor; // we are guarded by the above volatile read + if (wake != null && wakeIf == sequenceId) + LockSupport.unpark(wake); + } + + Chunk ensureChunk(long chunkSequenceId) throws InterruptedException + { + Chunk chunk = chunkList.get(chunkSequenceId); + if (chunk == null) + { + Map.Entry e; + while ( null != (e = chunkList.firstEntry()) && chunkSequenceId - e.getKey() > 1 << 12) + { + WaitQueue.Signal signal = writerWaiting.register(); + if (null != (e = chunkList.firstEntry()) && chunkSequenceId - e.getKey() > 1 << 12) + signal.await(); + else + signal.cancel(); + } + chunk = chunkList.get(chunkSequenceId); + if (chunk == null) + { + synchronized (this) + { + chunk = chunkList.get(chunkSequenceId); + if (chunk == null) + chunkList.put(chunkSequenceId, chunk = new Chunk(chunkSequenceId)); + } + } + } + return chunk; + } + + Chunk readerChunk(long readerId) throws InterruptedException + { + long chunkSequenceId = readerId & -CHUNK_SIZE; + if (readerChunk.sequenceId != chunkSequenceId) + readerChunk = ensureChunk(chunkSequenceId); + return readerChunk; + } + + public Event await(long id, long timeout, TimeUnit unit) throws InterruptedException + { + return await(id, System.nanoTime() + unit.toNanos(timeout)); + } + + public Event 
await(long id, long deadlineNanos) throws InterruptedException + { + Chunk chunk = readerChunk(id); + Event result = chunk.get(id); + if (result != null) + return result; + + readerWaitingFor = id; + readerWaiting = Thread.currentThread(); + while (null == (result = chunk.get(id))) + { + long waitNanos = deadlineNanos - System.nanoTime(); + if (waitNanos <= 0) + return null; + LockSupport.parkNanos(waitNanos); + if (Thread.interrupted()) + throw new InterruptedException(); + } + readerWaitingFor = -1; + readerWaiting = null; + return result; + } + + public Event find(long sequenceId) + { + long chunkSequenceId = sequenceId & -CHUNK_SIZE; + Chunk chunk = readerChunk; + if (chunk.sequenceId != chunkSequenceId) + { + chunk = writerChunk; + if (chunk.sequenceId != chunkSequenceId) + chunk = chunkList.get(chunkSequenceId); + } + return chunk.get(sequenceId); + } + + public void clear(long sequenceId) + { + long chunkSequenceId = sequenceId & -CHUNK_SIZE; + Chunk chunk = chunkList.get(chunkSequenceId); + chunk.set(sequenceId, null); + if (++chunk.removed == CHUNK_SIZE) + { + chunkList.remove(chunkSequenceId); + writerWaiting.signalAll(); + } + } + } + + static class Queue + { + private Object[] items = new Object[10]; + private int begin, end; + + int size() + { + return end - begin; + } + + T get(int i) + { + //noinspection unchecked + return (T) items[i + begin]; + } + + int indexOf(T item) + { + for (int i = begin ; i < end ; ++i) + { + if (item == items[i]) + return i - begin; + } + return -1; + } + + void remove(T item) + { + int i = indexOf(item); + if (i >= 0) + remove(i); + } + + void remove(int i) + { + i += begin; + assert i < end; + + if (i == begin || i + 1 == end) + { + items[i] = null; + if (begin + 1 == end) begin = end = 0; + else if (i == begin) ++begin; + else --end; + } + else if (i - begin < end - i) + { + System.arraycopy(items, begin, items, begin + 1, i - begin); + items[begin++] = null; + } + else + { + System.arraycopy(items, i + 1, items, i, 
(end - 1) - i); + items[--end] = null; + } + } + + void add(T item) + { + if (end == items.length) + { + Object[] src = items; + Object[] trg; + if (end - begin < src.length / 2) trg = src; + else trg = new Object[src.length * 2]; + System.arraycopy(src, begin, trg, 0, end - begin); + end -= begin; + begin = 0; + items = trg; + } + items[end++] = item; + } + + void clear() + { + Arrays.fill(items, begin, end, null); + begin = end = 0; + } + + void removeFirst(int count) + { + Arrays.fill(items, begin, begin + count, null); + begin += count; + if (begin == end) + begin = end = 0; + } + + T poll() + { + if (begin == end) + return null; + //noinspection unchecked + T result = (T) items[begin]; + items[begin++] = null; + if (begin == end) + begin = end = 0; + return result; + } + + void forEach(Consumer consumer) + { + for (int i = 0 ; i < size() ; ++i) + consumer.accept(get(i)); + } + + boolean isEmpty() + { + return begin == end; + } + + public String toString() + { + StringBuilder result = new StringBuilder(); + result.append('['); + toString(result); + result.append(']'); + return result.toString(); + } + + void toString(StringBuilder out) + { + for (int i = 0 ; i < size() ; ++i) + { + if (i > 0) out.append(", "); + out.append(get(i)); + } + } + } + + + + static class FramesInFlight + { + // this may be negative, indicating we have processed a frame whose status we did not know at the time + // TODO: we should verify the status of these frames by logging the inferred status and verifying it matches + final Queue inFlight = new Queue<>(); + final Queue retiredWithoutStatus = new Queue<>(); + private int withStatus; + + Frame supplySendStatus(Frame.Status status) + { + Frame frame; + if (withStatus >= 0) frame = inFlight.get(withStatus); + else frame = retiredWithoutStatus.poll(); + assert frame.sendStatus == Frame.Status.UNKNOWN; + frame.sendStatus = status; + ++withStatus; + return frame; + } + + boolean isEmpty() + { + return inFlight.isEmpty(); + } + + int size() 
+ { + return inFlight.size(); + } + + Frame get(int i) + { + return inFlight.get(i); + } + + void add(Frame frame) + { + assert frame.sendStatus == Frame.Status.UNKNOWN; + inFlight.add(frame); + } + + void remove(Frame frame) + { + int i = inFlight.indexOf(frame); + if (i > 0) throw new IllegalStateException(); + if (i == 0) poll(); + } + + void removeFirst(int count) + { + while (count-- > 0) + poll(); + } + + Frame poll() + { + Frame frame = inFlight.poll(); + if (--withStatus < 0) + { + assert frame.sendStatus == Frame.Status.UNKNOWN; + retiredWithoutStatus.add(frame); + } + else + assert frame.sendStatus != Frame.Status.UNKNOWN; + return frame; + } + + public String toString() + { + StringBuilder result = new StringBuilder(); + result.append("[withStatus="); + result.append(withStatus); + result.append("; "); + inFlight.toString(result); + result.append("; "); + retiredWithoutStatus.toString(result); + result.append(']'); + return result.toString(); + } + } + + private static boolean willProcessOnEventLoop(ConnectionType type, Message message, int messagingVersion) + { + int size = message.serializedSize(messagingVersion); + if (type == ConnectionType.SMALL_MESSAGES && messagingVersion >= VERSION_40) + return size <= LARGE_MESSAGE_THRESHOLD; + else if (messagingVersion >= VERSION_40) + return size <= DEFAULT_BUFFER_SIZE; + else + return size <= LARGE_MESSAGE_THRESHOLD; + } + + private static long expiresAtNanos(Message message, int messagingVersion) + { + return messagingVersion < VERSION_40 ? 
message.verb().expiresAtNanos(message.createdAtNanos()) + : message.expiresAtNanos(); + } + +} diff --git a/test/burn/org/apache/cassandra/utils/memory/LongBufferPoolTest.java b/test/burn/org/apache/cassandra/utils/memory/LongBufferPoolTest.java index 57aa940abd94..838038a6756b 100644 --- a/test/burn/org/apache/cassandra/utils/memory/LongBufferPoolTest.java +++ b/test/burn/org/apache/cassandra/utils/memory/LongBufferPoolTest.java @@ -33,6 +33,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.concurrent.NamedThreadFactory; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.utils.DynamicList; @@ -68,6 +69,42 @@ public class LongBufferPoolTest private static final int STDEV_BUFFER_SIZE = 10 << 10; // picked to ensure exceeding buffer size is rare, but occurs private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); + static final class Debug implements BufferPool.Debug + { + static class DebugChunk + { + volatile long lastRecycled; + static DebugChunk get(BufferPool.Chunk chunk) + { + if (chunk.debugAttachment == null) + chunk.debugAttachment = new DebugChunk(); + return (DebugChunk) chunk.debugAttachment; + } + } + long recycleRound = 1; + final List normalChunks = new ArrayList<>(); + final List tinyChunks = new ArrayList<>(); + public synchronized void registerNormal(BufferPool.Chunk chunk) + { + chunk.debugAttachment = new DebugChunk(); + normalChunks.add(chunk); + } + public void recycleNormal(BufferPool.Chunk oldVersion, BufferPool.Chunk newVersion) + { + newVersion.debugAttachment = oldVersion.debugAttachment; + DebugChunk.get(oldVersion).lastRecycled = recycleRound; + } + public synchronized void check() + { +// for (BufferPool.Chunk chunk : tinyChunks) +// assert DebugChunk.get(chunk).lastRecycled == recycleRound; + for (BufferPool.Chunk chunk : normalChunks) + assert DebugChunk.get(chunk).lastRecycled == recycleRound; + tinyChunks.clear(); // they don't 
survive a recycleRound + recycleRound++; + } + } + @BeforeClass public static void setup() throws Exception { @@ -133,7 +170,7 @@ private static final class TestEnvironment makingProgress = new AtomicBoolean[threadCount]; burnFreed = new AtomicBoolean(false); freedAllMemory = new AtomicBoolean[threadCount]; - executorService = Executors.newFixedThreadPool(threadCount + 2); + executorService = Executors.newFixedThreadPool(threadCount + 2, new NamedThreadFactory("test")); threadResultFuture = new ArrayList<>(threadCount); for (int i = 0; i < sharedRecycle.length; i++) @@ -149,7 +186,7 @@ private static final class TestEnvironment // using their own algorithm the targetSize should be poolSize / targetSizeQuanta. // // This should divide double the poolSize across the working threads, - // plus CHUNK_SIZE for thread0 and 1/10 poolSize for the burn producer/consumer pair. + // plus NORMAL_CHUNK_SIZE for thread0 and 1/10 poolSize for the burn producer/consumer pair. targetSizeQuanta = 2 * poolSize / sum1toN(threadCount - 1); } @@ -209,7 +246,8 @@ public void testAllocate(int threadCount, long duration, int poolSize) throws In long prevPoolSize = BufferPool.MEMORY_USAGE_THRESHOLD; logger.info("Overriding configured BufferPool.MEMORY_USAGE_THRESHOLD={} and enabling BufferPool.DEBUG", poolSize); BufferPool.MEMORY_USAGE_THRESHOLD = poolSize; - BufferPool.DEBUG = true; + Debug debug = new Debug(); + BufferPool.debug(debug); TestEnvironment testEnv = new TestEnvironment(threadCount, duration, poolSize); @@ -230,7 +268,7 @@ public void testAllocate(int threadCount, long duration, int poolSize) throws In for (AtomicBoolean freedMemory : testEnv.freedAllMemory) allFreed = allFreed && freedMemory.getAndSet(false); if (allFreed) - BufferPool.assertAllRecycled(); + debug.check(); else logger.info("All threads did not free all memory in this time slot - skipping buffer recycle check"); } @@ -250,7 +288,7 @@ public void testAllocate(int threadCount, long duration, int poolSize) throws 
In logger.info("Reverting BufferPool.MEMORY_USAGE_THRESHOLD={}", prevPoolSize); BufferPool.MEMORY_USAGE_THRESHOLD = prevPoolSize; - BufferPool.DEBUG = false; + BufferPool.debug(null); testEnv.assertCheckedThreadsSucceeded(); @@ -262,7 +300,7 @@ private Future startWorkerThread(TestEnvironment testEnv, final int thr { return testEnv.executorService.submit(new TestUntil(testEnv.until) { - final int targetSize = threadIdx == 0 ? BufferPool.CHUNK_SIZE : testEnv.targetSizeQuanta * threadIdx; + final int targetSize = threadIdx == 0 ? BufferPool.NORMAL_CHUNK_SIZE : testEnv.targetSizeQuanta * threadIdx; final SPSCQueue shareFrom = testEnv.sharedRecycle[threadIdx]; final DynamicList checks = new DynamicList<>((int) Math.max(1, targetSize / (1 << 10))); final SPSCQueue shareTo = testEnv.sharedRecycle[(threadIdx + 1) % testEnv.threadCount]; @@ -279,7 +317,6 @@ void checkpoint() void testOne() throws Exception { - long currentTargetSize = (rand.nextInt(testEnv.poolSize / 1024) == 0 || !testEnv.freedAllMemory[threadIdx].get()) ? 
0 : targetSize; int spinCount = 0; while (totalSize > currentTargetSize - freeingSize) @@ -309,8 +346,8 @@ else if (!recycleFromNeighbour()) checks.remove(check.listnode); check.validate(); - size = BufferPool.roundUpNormal(check.buffer.capacity()); - if (size > BufferPool.CHUNK_SIZE) + size = BufferPool.roundUp(check.buffer.capacity()); + if (size > BufferPool.NORMAL_CHUNK_SIZE) size = 0; // either share to free, or free immediately @@ -334,9 +371,9 @@ else if (!recycleFromNeighbour()) // allocate a new buffer size = (int) Math.max(1, AVG_BUFFER_SIZE + (STDEV_BUFFER_SIZE * rand.nextGaussian())); - if (size <= BufferPool.CHUNK_SIZE) + if (size <= BufferPool.NORMAL_CHUNK_SIZE) { - totalSize += BufferPool.roundUpNormal(size); + totalSize += BufferPool.roundUp(size); allocate(size); } else if (rand.nextBoolean()) @@ -349,10 +386,10 @@ else if (rand.nextBoolean()) while (totalSize < testEnv.poolSize) { size = (int) Math.max(1, AVG_BUFFER_SIZE + (STDEV_BUFFER_SIZE * rand.nextGaussian())); - if (size <= BufferPool.CHUNK_SIZE) + if (size <= BufferPool.NORMAL_CHUNK_SIZE) { allocate(size); - totalSize += BufferPool.roundUpNormal(size); + totalSize += BufferPool.roundUp(size); } } } @@ -443,7 +480,7 @@ private void startBurnerThreads(TestEnvironment testEnv) final ThreadLocalRandom rand = ThreadLocalRandom.current(); void testOne() throws Exception { - if (count * BufferPool.CHUNK_SIZE >= testEnv.poolSize / 10) + if (count * BufferPool.NORMAL_CHUNK_SIZE >= testEnv.poolSize / 10) { if (burn.exhausted) { @@ -456,7 +493,8 @@ void testOne() throws Exception return; } - ByteBuffer buffer = BufferPool.tryGet(BufferPool.CHUNK_SIZE); + ByteBuffer buffer = rand.nextInt(4) < 1 ? 
BufferPool.tryGet(BufferPool.NORMAL_CHUNK_SIZE) + : BufferPool.tryGet(BufferPool.TINY_ALLOCATION_LIMIT); if (buffer == null) { Thread.yield(); @@ -523,7 +561,7 @@ public Boolean call() throws Exception { logger.error("Got exception {}, current chunk {}", ex.getMessage(), - BufferPool.currentChunk()); + BufferPool.unsafeCurrentChunk()); ex.printStackTrace(); return false; } @@ -531,7 +569,7 @@ public Boolean call() throws Exception { logger.error("Got throwable {}, current chunk {}", tr.getMessage(), - BufferPool.currentChunk()); + BufferPool.unsafeCurrentChunk()); tr.printStackTrace(); return false; } @@ -555,6 +593,7 @@ public static void main(String[] args) catch (Throwable tr) { System.out.println(String.format("Test failed - %s", tr.getMessage())); + tr.printStackTrace(); System.exit(1); // Force exit so that non-daemon threads like REQUEST-SCHEDULER do not hang the process on failure } } diff --git a/test/conf/cassandra-murmur.yaml b/test/conf/cassandra-murmur.yaml index e933837ccdee..3c263a5af17c 100644 --- a/test/conf/cassandra-murmur.yaml +++ b/test/conf/cassandra-murmur.yaml @@ -13,7 +13,7 @@ cdc_enabled: false hints_directory: build/test/cassandra/hints partitioner: org.apache.cassandra.dht.Murmur3Partitioner listen_address: 127.0.0.1 -storage_port: 7010 +storage_port: 7012 start_native_transport: true native_transport_port: 9042 column_index_size_in_kb: 4 @@ -24,7 +24,7 @@ disk_access_mode: mmap seed_provider: - class_name: org.apache.cassandra.locator.SimpleSeedProvider parameters: - - seeds: "127.0.0.1" + - seeds: "127.0.0.1:7012" endpoint_snitch: org.apache.cassandra.locator.SimpleSnitch dynamic_snitch: true server_encryption_options: diff --git a/test/conf/cassandra-seeds.yaml b/test/conf/cassandra-seeds.yaml index 02d25d232cfb..f3279aeb1253 100644 --- a/test/conf/cassandra-seeds.yaml +++ b/test/conf/cassandra-seeds.yaml @@ -14,7 +14,7 @@ cdc_enabled: false hints_directory: build/test/cassandra/hints partitioner: 
org.apache.cassandra.dht.ByteOrderedPartitioner listen_address: 127.0.0.1 -storage_port: 7010 +storage_port: 7012 start_native_transport: true native_transport_port: 9042 column_index_size_in_kb: 4 diff --git a/test/conf/cassandra.yaml b/test/conf/cassandra.yaml index d94c478b9a41..89b7ff180769 100644 --- a/test/conf/cassandra.yaml +++ b/test/conf/cassandra.yaml @@ -16,7 +16,7 @@ cdc_enabled: false hints_directory: build/test/cassandra/hints partitioner: org.apache.cassandra.dht.ByteOrderedPartitioner listen_address: 127.0.0.1 -storage_port: 7010 +storage_port: 7012 ssl_storage_port: 7011 start_native_transport: true native_transport_port: 9042 @@ -28,7 +28,7 @@ disk_access_mode: mmap seed_provider: - class_name: org.apache.cassandra.locator.SimpleSeedProvider parameters: - - seeds: "127.0.0.1:7010" + - seeds: "127.0.0.1:7012" endpoint_snitch: org.apache.cassandra.locator.SimpleSnitch dynamic_snitch: true server_encryption_options: diff --git a/test/conf/logback-burntest.xml b/test/conf/logback-burntest.xml new file mode 100644 index 000000000000..e1e48a9d3fae --- /dev/null +++ b/test/conf/logback-burntest.xml @@ -0,0 +1,66 @@ + + + + + + + + + + + ./build/test/logs/${cassandra.testtag}/TEST-${suitename}.log + + ./build/test/logs/${cassandra.testtag}/TEST-${suitename}.log.%i.gz + 1 + 20 + + + + 20MB + + + + %-5level [%thread] ${instance_id} %date{ISO8601} %msg%n + + false + + + + 0 + 0 + 1024 + + + + + + + %-5level [%thread] ${instance_id} %date{ISO8601} %F:%L - %msg%n + + + + + + + + + + + + diff --git a/test/conf/logback-dtest.xml b/test/conf/logback-dtest.xml index b62539fb3c58..370e1e5bb224 100644 --- a/test/conf/logback-dtest.xml +++ b/test/conf/logback-dtest.xml @@ -23,7 +23,7 @@ - + ./build/test/logs/${cassandra.testtag}/TEST-${suitename}.log @@ -42,14 +42,14 @@ false - + 0 0 1024 - + - + %-5level %date{HH:mm:ss,SSS} %msg%n @@ -58,16 +58,7 @@ - - - %-5level %date{HH:mm:ss,SSS} %msg%n - - - WARN - - - - + %-5level [%thread] ${instance_id} %date{ISO8601} %F:%L 
- %msg%n @@ -79,8 +70,8 @@ - - - + + + diff --git a/test/data/serialization/4.0/service.SyncComplete.bin b/test/data/serialization/4.0/service.SyncComplete.bin index 15cccb85be8e..4e8caa6c3743 100644 Binary files a/test/data/serialization/4.0/service.SyncComplete.bin and b/test/data/serialization/4.0/service.SyncComplete.bin differ diff --git a/test/data/serialization/4.0/service.SyncRequest.bin b/test/data/serialization/4.0/service.SyncRequest.bin index f4eb53285db6..b0cc44eee16c 100644 Binary files a/test/data/serialization/4.0/service.SyncRequest.bin and b/test/data/serialization/4.0/service.SyncRequest.bin differ diff --git a/test/data/serialization/4.0/service.ValidationComplete.bin b/test/data/serialization/4.0/service.ValidationComplete.bin index edc90b359706..7402c9e848d5 100644 Binary files a/test/data/serialization/4.0/service.ValidationComplete.bin and b/test/data/serialization/4.0/service.ValidationComplete.bin differ diff --git a/test/data/serialization/4.0/service.ValidationRequest.bin b/test/data/serialization/4.0/service.ValidationRequest.bin index e45eb703e088..fa4a9138426b 100644 Binary files a/test/data/serialization/4.0/service.ValidationRequest.bin and b/test/data/serialization/4.0/service.ValidationRequest.bin differ diff --git a/test/distributed/org/apache/cassandra/distributed/Cluster.java b/test/distributed/org/apache/cassandra/distributed/Cluster.java index c7f7675cb847..95862b6a627a 100644 --- a/test/distributed/org/apache/cassandra/distributed/Cluster.java +++ b/test/distributed/org/apache/cassandra/distributed/Cluster.java @@ -20,8 +20,8 @@ import java.io.File; import java.io.IOException; -import java.nio.file.Files; import java.util.List; +import java.util.function.Consumer; import org.apache.cassandra.distributed.api.ICluster; import org.apache.cassandra.distributed.impl.AbstractCluster; @@ -40,18 +40,24 @@ private Cluster(File root, Versions.Version version, List config super(root, version, configs, sharedClassLoader); } - protected 
IInvokableInstance newInstanceWrapper(Versions.Version version, InstanceConfig config) + protected IInvokableInstance newInstanceWrapper(int generation, Versions.Version version, InstanceConfig config) { - return new Wrapper(version, config); + return new Wrapper(generation, version, config); } - public static Cluster create(int nodeCount) throws Throwable + public static Builder build(int nodeCount) + { + return new Builder<>(nodeCount, Cluster::new); + } + + public static Cluster create(int nodeCount, Consumer configUpdater) throws IOException { - return create(nodeCount, Cluster::new); + return build(nodeCount).withConfig(configUpdater).start(); } - public static Cluster create(int nodeCount, File root) + + public static Cluster create(int nodeCount) throws Throwable { - return create(nodeCount, Versions.CURRENT, root, Cluster::new); + return build(nodeCount).start(); } } diff --git a/test/distributed/org/apache/cassandra/distributed/UpgradeableCluster.java b/test/distributed/org/apache/cassandra/distributed/UpgradeableCluster.java index 0c8e63ae9101..9a270376c07e 100644 --- a/test/distributed/org/apache/cassandra/distributed/UpgradeableCluster.java +++ b/test/distributed/org/apache/cassandra/distributed/UpgradeableCluster.java @@ -19,12 +19,11 @@ package org.apache.cassandra.distributed; import java.io.File; -import java.io.IOException; -import java.nio.file.Files; import java.util.List; import org.apache.cassandra.distributed.api.ICluster; import org.apache.cassandra.distributed.impl.AbstractCluster; +import org.apache.cassandra.distributed.impl.IInvokableInstance; import org.apache.cassandra.distributed.impl.IUpgradeableInstance; import org.apache.cassandra.distributed.impl.InstanceConfig; import org.apache.cassandra.distributed.impl.Versions; @@ -43,28 +42,24 @@ private UpgradeableCluster(File root, Versions.Version version, List build(int nodeCount) { - return create(nodeCount, Versions.CURRENT, root, UpgradeableCluster::new); + return new 
Builder<>(nodeCount, UpgradeableCluster::new); } - public static UpgradeableCluster create(int nodeCount, Versions.Version version) throws IOException + public static UpgradeableCluster create(int nodeCount) throws Throwable { - return create(nodeCount, version, Files.createTempDirectory("dtests").toFile(), UpgradeableCluster::new); + return build(nodeCount).start(); } - public static UpgradeableCluster create(int nodeCount, Versions.Version version, File root) + + public static UpgradeableCluster create(int nodeCount, Versions.Version version) throws Throwable { - return create(nodeCount, version, root, UpgradeableCluster::new); + return build(nodeCount).withVersion(version).start(); } - } diff --git a/test/distributed/org/apache/cassandra/distributed/api/Feature.java b/test/distributed/org/apache/cassandra/distributed/api/Feature.java new file mode 100644 index 000000000000..a5c9316930e3 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/api/Feature.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.api; + +public enum Feature +{ + NETWORK, GOSSIP +} diff --git a/test/distributed/org/apache/cassandra/distributed/api/IInstance.java b/test/distributed/org/apache/cassandra/distributed/api/IInstance.java index 3834093fbbec..d5382b4970f0 100644 --- a/test/distributed/org/apache/cassandra/distributed/api/IInstance.java +++ b/test/distributed/org/apache/cassandra/distributed/api/IInstance.java @@ -37,7 +37,9 @@ public interface IInstance extends IIsolatedExecutor UUID schemaVersion(); void startup(); + boolean isShutdown(); Future shutdown(); + Future shutdown(boolean graceful); // these methods are not for external use, but for simplicity we leave them public and on the normal IInstance interface void startup(ICluster cluster); diff --git a/test/distributed/org/apache/cassandra/distributed/api/IInstanceConfig.java b/test/distributed/org/apache/cassandra/distributed/api/IInstanceConfig.java index 6741b3fdd59b..3e5a18fe7c8c 100644 --- a/test/distributed/org/apache/cassandra/distributed/api/IInstanceConfig.java +++ b/test/distributed/org/apache/cassandra/distributed/api/IInstanceConfig.java @@ -38,4 +38,5 @@ public interface IInstanceConfig Object get(String fieldName); String getString(String fieldName); int getInt(String fieldName); + boolean has(Feature featureFlag); } diff --git a/test/distributed/org/apache/cassandra/distributed/api/IMessage.java b/test/distributed/org/apache/cassandra/distributed/api/IMessage.java index 1e537ed1ee82..7bc7931a83c1 100644 --- a/test/distributed/org/apache/cassandra/distributed/api/IMessage.java +++ b/test/distributed/org/apache/cassandra/distributed/api/IMessage.java @@ -27,6 +27,7 @@ public interface IMessage { int verb(); byte[] bytes(); + // TODO: need to make this a long int id(); int version(); InetAddressAndPort from(); diff --git a/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java 
b/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java index b5fde840e748..426aa5edef94 100644 --- a/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java +++ b/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java @@ -18,10 +18,8 @@ package org.apache.cassandra.distributed.api; -import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.MessagingService; - -import java.util.function.BiConsumer; +import org.apache.cassandra.net.Verb; public interface IMessageFilters { @@ -39,10 +37,11 @@ public interface Builder Filter drop(); } - Builder verbs(MessagingService.Verb... verbs); + Builder verbs(Verb ... verbs); Builder allVerbs(); void reset(); // internal - BiConsumer filter(BiConsumer applyIfNotFiltered); + boolean permit(IInstance from, IInstance to, int verb); + } diff --git a/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java b/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java index 93a749815f63..07dd99e1be4f 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java @@ -30,11 +30,9 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -57,7 +55,7 @@ import org.apache.cassandra.distributed.api.ICluster; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; import 
org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.concurrent.SimpleCondition; @@ -91,6 +89,7 @@ public abstract class AbstractCluster implements ICluster, // to ensure we have instantiated the main classloader's LoggerFactory (and any LogbackStatusListener) // before we instantiate any for a new instance private static final Logger logger = LoggerFactory.getLogger(AbstractCluster.class); + private static final AtomicInteger generation = new AtomicInteger(); private final File root; private final ClassLoader sharedClassLoader; @@ -104,6 +103,7 @@ public abstract class AbstractCluster implements ICluster, protected class Wrapper extends DelegatingInvokableInstance implements IUpgradeableInstance { + private final int generation; private final InstanceConfig config; private volatile IInvokableInstance delegate; private volatile Versions.Version version; @@ -112,21 +112,22 @@ protected class Wrapper extends DelegatingInvokableInstance implements IUpgradea protected IInvokableInstance delegate() { if (delegate == null) - delegate = newInstance(); + delegate = newInstance(generation); return delegate; } - public Wrapper(Versions.Version version, InstanceConfig config) + public Wrapper(int generation, Versions.Version version, InstanceConfig config) { + this.generation = generation; this.config = config; this.version = version; // we ensure there is always a non-null delegate, so that the executor may be used while the node is offline - this.delegate = newInstance(); + this.delegate = newInstance(generation); } - private IInvokableInstance newInstance() + private IInvokableInstance newInstance(int generation) { - ClassLoader classLoader = new InstanceClassLoader(config.num(), version.classpath, sharedClassLoader); + ClassLoader classLoader = new InstanceClassLoader(generation, version.classpath, sharedClassLoader); return Instance.transferAdhoc((SerializableBiFunction)Instance::new, classLoader) .apply(config.forVersion(version.major), 
classLoader); } @@ -136,6 +137,11 @@ public IInstanceConfig config() return config; } + public boolean isShutdown() + { + return isShutdown; + } + @Override public synchronized void startup() { @@ -148,11 +154,17 @@ public synchronized void startup() @Override public synchronized Future shutdown() + { + return shutdown(true); + } + + @Override + public synchronized Future shutdown(boolean graceful) { if (isShutdown) throw new IllegalStateException(); isShutdown = true; - Future future = delegate.shutdown(); + Future future = delegate.shutdown(graceful); delegate = null; return future; } @@ -187,19 +199,20 @@ protected AbstractCluster(File root, Versions.Version version, List(); this.instanceMap = new HashMap<>(); + int generation = AbstractCluster.generation.incrementAndGet(); for (InstanceConfig config : configs) { - I instance = newInstanceWrapper(version, config); + I instance = newInstanceWrapper(generation, version, config); instances.add(instance); // we use the config().broadcastAddressAndPort() here because we have not initialised the Instance I prev = instanceMap.put(instance.broadcastAddressAndPort(), instance); if (null != prev) throw new IllegalStateException("Cluster cannot have multiple nodes with same InetAddressAndPort: " + instance.broadcastAddressAndPort() + " vs " + prev.broadcastAddressAndPort()); } - this.filters = new MessageFilters(this); + this.filters = new MessageFilters(); } - protected abstract I newInstanceWrapper(Versions.Version version, InstanceConfig config); + protected abstract I newInstanceWrapper(int generation, Versions.Version version, InstanceConfig config); /** * WARNING: we index from 1 here, for consistency with inet address! @@ -231,7 +244,7 @@ public void parallelForEach(IIsolatedExecutor.SerializableConsumer co public IMessageFilters filters() { return filters; } - public MessageFilters.Builder verbs(MessagingService.Verb ... verbs) { return filters.verbs(verbs); } + public MessageFilters.Builder verbs(Verb... 
verbs) { return filters.verbs(verbs); } public void disableAutoCompaction(String keyspace) { @@ -257,9 +270,12 @@ private void updateMessagingVersions() { for (IInstance reportTo: instances) { + if (reportTo.isShutdown()) + continue; + for (IInstance reportFrom: instances) { - if (reportFrom == reportTo) + if (reportFrom == reportTo || reportFrom.isShutdown()) continue; int minVersion = Math.min(reportFrom.getMessagingVersion(), reportTo.getMessagingVersion()); @@ -335,46 +351,83 @@ protected interface Factory> C newCluster(File root, Versions.Version version, List configs, ClassLoader sharedClassLoader); } - protected static > C - create(int nodeCount, Factory factory) throws Throwable + public static class Builder> { - return create(nodeCount, Files.createTempDirectory("dtests").toFile(), factory); - } + private final int nodeCount; + private final Factory factory; + private int subnet; + private File root; + private Versions.Version version; + private Consumer configUpdater; + public Builder(int nodeCount, Factory factory) + { + this.nodeCount = nodeCount; + this.factory = factory; + } - protected static > C - create(int nodeCount, File root, Factory factory) - { - return create(nodeCount, Versions.CURRENT, root, factory); - } + public Builder withSubnet(int subnet) + { + this.subnet = subnet; + return this; + } - protected static > C - create(int nodeCount, Versions.Version version, Factory factory) throws IOException - { - return create(nodeCount, version, Files.createTempDirectory("dtests").toFile(), factory); - } + public Builder withRoot(File root) + { + this.root = root; + return this; + } - protected static > C - create(int nodeCount, Versions.Version version, File root, Factory factory) - { - root.mkdirs(); - setupLogging(root); + public Builder withVersion(Versions.Version version) + { + this.version = version; + return this; + } - ClassLoader sharedClassLoader = Thread.currentThread().getContextClassLoader(); + public Builder withConfig(Consumer 
updater) + { + this.configUpdater = updater; + return this; + } - List configs = new ArrayList<>(); - long token = Long.MIN_VALUE + 1, increment = 2 * (Long.MAX_VALUE / nodeCount); - for (int i = 0 ; i < nodeCount ; ++i) + public C createWithoutStarting() throws IOException { - InstanceConfig config = InstanceConfig.generate(i + 1, root, String.valueOf(token)); - configs.add(config); - token += increment; + File root = this.root; + Versions.Version version = this.version; + + if (root == null) + root = Files.createTempDirectory("dtests").toFile(); + if (version == null) + version = Versions.CURRENT; + + root.mkdirs(); + setupLogging(root); + + ClassLoader sharedClassLoader = Thread.currentThread().getContextClassLoader(); + + List configs = new ArrayList<>(); + long token = Long.MIN_VALUE + 1, increment = 2 * (Long.MAX_VALUE / nodeCount); + for (int i = 0; i < nodeCount; ++i) + { + InstanceConfig config = InstanceConfig.generate(i + 1, subnet, root, String.valueOf(token)); + if (configUpdater != null) + configUpdater.accept(config); + configs.add(config); + token += increment; + } + + C cluster = factory.newCluster(root, version, configs, sharedClassLoader); + return cluster; } - C cluster = factory.newCluster(root, version, configs, sharedClassLoader); - cluster.startup(); - return cluster; + public C start() throws IOException + { + C cluster = createWithoutStarting(); + cluster.startup(); + return cluster; + } } + private static void setupLogging(File root) { try @@ -398,6 +451,7 @@ private static void setupLogging(File root) public void close() { FBUtilities.waitOnFutures(instances.stream() + .filter(i -> !i.isShutdown()) .map(IInstance::shutdown) .collect(Collectors.toList()), 1L, TimeUnit.MINUTES); diff --git a/src/java/org/apache/cassandra/io/ShortVersionedSerializer.java b/test/distributed/org/apache/cassandra/distributed/impl/ExecUtil.java similarity index 53% rename from src/java/org/apache/cassandra/io/ShortVersionedSerializer.java rename to 
test/distributed/org/apache/cassandra/distributed/impl/ExecUtil.java index 8731f4c94f7c..b907626ae85a 100644 --- a/src/java/org/apache/cassandra/io/ShortVersionedSerializer.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/ExecUtil.java @@ -16,32 +16,36 @@ * limitations under the License. */ -package org.apache.cassandra.io; +package org.apache.cassandra.distributed.impl; -import java.io.IOException; +import java.io.Serializable; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.distributed.api.IIsolatedExecutor; -public class ShortVersionedSerializer implements IVersionedSerializer +public class ExecUtil { - public static final ShortVersionedSerializer instance = new ShortVersionedSerializer(); - - private ShortVersionedSerializer() {} - - public void serialize(Short aShort, DataOutputPlus out, int version) throws IOException + public interface ThrowingSerializableRunnable extends Serializable { - out.writeShort(aShort); + public void run() throws T; } - public Short deserialize(DataInputPlus in, int version) throws IOException + public static IIsolatedExecutor.SerializableRunnable rethrow(ThrowingSerializableRunnable run) { - return in.readShort(); + return () -> { + try + { + run.run(); + } + catch (RuntimeException | Error t) + { + throw t; + } + catch (Throwable t) + { + throw new RuntimeException(t); + } + }; } - public long serializedSize(Short aShort, int version) - { - return 2; - } } diff --git a/test/distributed/org/apache/cassandra/distributed/impl/Instance.java b/test/distributed/org/apache/cassandra/distributed/impl/Instance.java index 53e109afb89d..9294bfee4574 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/Instance.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/Instance.java @@ -27,12 +27,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import 
java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.BiConsumer; +import java.util.function.BiPredicate; -import org.slf4j.LoggerFactory; +import io.netty.util.concurrent.GlobalEventExecutor; -import ch.qos.logback.classic.LoggerContext; import org.apache.cassandra.batchlog.BatchlogManager; import org.apache.cassandra.concurrent.ScheduledExecutors; import org.apache.cassandra.concurrent.SharedExecutorPool; @@ -62,17 +62,13 @@ import org.apache.cassandra.gms.VersionedValue; import org.apache.cassandra.hints.HintsService; import org.apache.cassandra.index.SecondaryIndexManager; +import org.apache.cassandra.io.sstable.IndexSummaryManager; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IMessageSink; -import org.apache.cassandra.net.MessageDeliveryTask; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.MessageInHandler; -import org.apache.cassandra.net.async.NettyFactory; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; import org.apache.cassandra.service.ActiveRepairService; @@ -80,12 +76,20 @@ import org.apache.cassandra.service.PendingRangeCalculatorService; import org.apache.cassandra.service.QueryState; import org.apache.cassandra.service.StorageService; +import org.apache.cassandra.streaming.async.StreamingInboundHandler; +import org.apache.cassandra.streaming.StreamReceiveTask; +import org.apache.cassandra.streaming.StreamTransferTask; import org.apache.cassandra.transport.messages.ResultMessage; +import org.apache.cassandra.utils.ExecutorUtils; import 
org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.Throwables; import org.apache.cassandra.utils.concurrent.Ref; import org.apache.cassandra.utils.memory.BufferPool; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; + public class Instance extends IsolatedExecutor implements IInvokableInstance { public final IInstanceConfig config; @@ -98,6 +102,14 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance this.config = config; InstanceIDDefiner.setInstanceId(config.num()); FBUtilities.setBroadcastInetAddressAndPort(config.broadcastAddressAndPort()); + // Set the config at instance creation, possibly before startup() has run on all other instances. + // setMessagingVersions below will call runOnInstance which will instantiate + // the MessagingService and dependencies preventing later changes to network parameters. + Config.setOverrideLoadConfig(() -> loadConfig(config)); + + // Enable streaming inbound handler tracking so they can be closed properly without leaking + // the blocking IO thread. 
+ StreamingInboundHandler.trackInboundHandlers(); } @Override @@ -148,6 +160,11 @@ public void startup() throw new UnsupportedOperationException(); } + public boolean isShutdown() + { + throw new UnsupportedOperationException(); + } + @Override public void schemaChangeInternal(String query) { @@ -173,12 +190,21 @@ public void schemaChangeInternal(String query) private void registerMockMessaging(ICluster cluster) { BiConsumer deliverToInstance = (to, message) -> cluster.get(to).receiveMessage(message); - BiConsumer deliverToInstanceIfNotFiltered = cluster.filters().filter(deliverToInstance); + BiConsumer deliverToInstanceIfNotFiltered = (to, message) -> { + if (cluster.filters().permit(this, cluster.get(to), message.verb())) + deliverToInstance.accept(to, message); + }; - MessagingService.instance().addMessageSink(new MessageDeliverySink(deliverToInstanceIfNotFiltered)); + MessagingService.instance().outboundSink.add(new MessageDeliverySink(deliverToInstanceIfNotFiltered)); } - private class MessageDeliverySink implements IMessageSink + // unnecessary if registerMockMessaging used + private void registerFilter(ICluster cluster) + { + MessagingService.instance().outboundSink.add((message, to) -> cluster.filters().permit(this, cluster.get(to), message.verb().id)); + } + + private class MessageDeliverySink implements BiPredicate, InetAddressAndPort> { private final BiConsumer deliver; MessageDeliverySink(BiConsumer deliver) @@ -187,14 +213,13 @@ private class MessageDeliverySink implements IMessageSink } @Override - public boolean allowOutgoingMessage(MessageOut messageOut, int id, InetAddressAndPort to) + public boolean test(Message messageOut, InetAddressAndPort to) { try (DataOutputBuffer out = new DataOutputBuffer(1024)) { InetAddressAndPort from = broadcastAddressAndPort(); - int version = MessagingService.instance().getVersion(to); - messageOut.serialize(out, version); - deliver.accept(to, new Message(messageOut.verb.getId(), out.toByteArray(), id, version, 
from)); + Message.serializer.serialize(messageOut, out, MessagingService.current_version); + deliver.accept(to, new MessageImpl(messageOut.verb().id, out.toByteArray(), messageOut.id(), MessagingService.current_version, from)); } catch (IOException e) { @@ -202,13 +227,6 @@ public boolean allowOutgoingMessage(MessageOut messageOut, int id, InetAddressAn } return false; } - - @Override - public boolean allowIncomingMessage(MessageIn message, int id) - { - // we can filter to our heart's content on the outgoing message; no need to worry about incoming - return true; - } } @Override @@ -217,9 +235,8 @@ public void receiveMessage(IMessage message) sync(() -> { try (DataInputBuffer in = new DataInputBuffer(message.bytes())) { - MessageIn messageIn = MessageInHandler.deserialize(in, message.id(), message.version(), message.from()); - Runnable deliver = new MessageDeliveryTask(messageIn, message.id()); - deliver.run(); + Message messageIn = Message.serializer.deserialize(in, message.from(), message.version()); + messageIn.verb().handler().doVerb((Message) messageIn); } catch (Throwable t) { @@ -235,7 +252,7 @@ public int getMessagingVersion() public void setMessagingVersion(InetAddressAndPort endpoint, int version) { - runOnInstance(() -> MessagingService.instance().setVersion(endpoint, version)); + MessagingService.instance().versions.set(endpoint, version); } @Override @@ -244,10 +261,17 @@ public void startup(ICluster cluster) sync(() -> { try { + if (config.has(GOSSIP)) + { + // TODO: hacky + System.setProperty("cassandra.ring_delay_ms", "5000"); + System.setProperty("cassandra.consistent.rangemovement", "false"); + System.setProperty("cassandra.consistent.simultaneousmoves.allow", "true"); + } + mkdirs(); - Config.setOverrideLoadConfig(() -> loadConfig(config)); - DatabaseDescriptor.daemonInitialization(); + DatabaseDescriptor.daemonInitialization(); DatabaseDescriptor.createAllDirectories(); // We need to persist this as soon as possible after startup checks. 
@@ -277,13 +301,28 @@ public void startup(ICluster cluster) throw new RuntimeException(e); } - // Even though we don't use MessagingService, access the static NettyFactory - // instance here so that we start the static event loop state - // (e.g. acceptGroup, inboundGroup, outboundGroup, etc ...). We can remove this - // once we actually use the MessagingService to communicate between nodes - NettyFactory.instance.getClass(); - initializeRing(cluster); - registerMockMessaging(cluster); + if (config.has(NETWORK)) + { + registerFilter(cluster); + MessagingService.instance().listen(); + } + else + { + // Even though we don't use MessagingService, access the static SocketFactory + // instance here so that we start the static event loop state +// -- not sure what that means? SocketFactory.instance.getClass(); + registerMockMessaging(cluster); + } + + // TODO: this is more than just gossip + if (config.has(GOSSIP)) + { + StorageService.instance.initServer(); + } + else + { + initializeRing(cluster); + } SystemKeyspace.finishStartup(); @@ -356,8 +395,7 @@ private void initializeRing(ICluster cluster) new VersionedValue.VersionedValueFactory(partitioner).normal(Collections.singleton(token))); Gossiper.instance.realMarkAlive(ep, Gossiper.instance.getEndpointStateForEndpoint(ep)); }); - int version = Math.min(MessagingService.current_version, cluster.get(ep).getMessagingVersion()); - MessagingService.instance().setVersion(ep, version); + MessagingService.instance().versions.set(ep, MessagingService.current_version); } // check that all nodes are in token metadata @@ -370,40 +408,53 @@ private void initializeRing(ICluster cluster) } } - @Override public Future shutdown() + { + return shutdown(true); + } + + @Override + public Future shutdown(boolean graceful) { Future future = async((ExecutorService executor) -> { Throwable error = null; + + if (config.has(GOSSIP)) + { + StorageService.instance.shutdownServer(); + } + error = parallelRun(error, executor, - 
Gossiper.instance::stop, - CompactionManager.instance::forceShutdown, - BatchlogManager.instance::shutdown, - HintsService.instance::shutdownBlocking, - CommitLog.instance::shutdownBlocking, - SecondaryIndexManager::shutdownExecutors, - ColumnFamilyStore::shutdownFlushExecutor, - ColumnFamilyStore::shutdownPostFlushExecutor, - ColumnFamilyStore::shutdownReclaimExecutor, - ColumnFamilyStore::shutdownPerDiskFlushExecutors, - PendingRangeCalculatorService.instance::shutdownExecutor, - BufferPool::shutdownLocalCleaner, - Ref::shutdownReferenceReaper, - Memtable.MEMORY_POOL::shutdown, - ScheduledExecutors::shutdownAndWait, - SSTableReader::shutdownBlocking, - () -> shutdownAndWait(ActiveRepairService.repairCommandExecutor) + () -> Gossiper.instance.stopShutdownAndWait(1L, MINUTES), + CompactionManager.instance::forceShutdown, + () -> BatchlogManager.instance.shutdownAndWait(1L, MINUTES), + HintsService.instance::shutdownBlocking, + StreamingInboundHandler::shutdown, + () -> StreamReceiveTask.shutdownAndWait(1L, MINUTES), + () -> StreamTransferTask.shutdownAndWait(1L, MINUTES), + () -> SecondaryIndexManager.shutdownAndWait(1L, MINUTES), + () -> IndexSummaryManager.instance.shutdownAndWait(1L, MINUTES), + () -> ColumnFamilyStore.shutdownExecutorsAndWait(1L, MINUTES), + () -> PendingRangeCalculatorService.instance.shutdownAndWait(1L, MINUTES), + () -> BufferPool.shutdownLocalCleaner(1L, MINUTES), + () -> Ref.shutdownReferenceReaper(1L, MINUTES), + () -> Memtable.MEMORY_POOL.shutdownAndWait(1L, MINUTES), + () -> ScheduledExecutors.shutdownAndWait(1L, MINUTES), + () -> SSTableReader.shutdownBlocking(1L, MINUTES), + () -> shutdownAndWait(Collections.singletonList(ActiveRepairService.repairCommandExecutor)), + () -> ScheduledExecutors.shutdownAndWait(1L, MINUTES) ); + error = parallelRun(error, executor, - MessagingService.instance()::shutdown + CommitLog.instance::shutdownBlocking, + () -> MessagingService.instance().shutdown(1L, MINUTES, false, true) ); error = 
parallelRun(error, executor, - StageManager::shutdownAndWait, - SharedExecutorPool.SHARED::shutdown + () -> GlobalEventExecutor.INSTANCE.awaitInactivity(1l, MINUTES), + () -> StageManager.shutdownAndWait(1L, MINUTES), + () -> SharedExecutorPool.SHARED.shutdownAndWait(1L, MINUTES) ); - LoggerContext loggerContext = (LoggerContext) LoggerFactory.getILoggerFactory(); - loggerContext.stop(); Throwables.maybeFail(error); }).apply(isolatedExecutor); @@ -411,18 +462,10 @@ public Future shutdown() .thenRun(super::shutdown); } - private static void shutdownAndWait(ExecutorService executor) + private static void shutdownAndWait(List executors) throws TimeoutException, InterruptedException { - try - { - executor.shutdownNow(); - executor.awaitTermination(20, TimeUnit.SECONDS); - } - catch (InterruptedException e) - { - throw new RuntimeException(e); - } - assert executor.isTerminated() && executor.isShutdown() : executor; + ExecutorUtils.shutdownNow(executors); + ExecutorUtils.awaitTermination(1L, MINUTES, executors); } private static Throwable parallelRun(Throwable accumulate, ExecutorService runOn, ThrowingRunnable ... 
runnables) diff --git a/test/distributed/org/apache/cassandra/distributed/impl/InstanceClassLoader.java b/test/distributed/org/apache/cassandra/distributed/impl/InstanceClassLoader.java index 56c80740d9d3..aa45d272381d 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/InstanceClassLoader.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/InstanceClassLoader.java @@ -20,6 +20,8 @@ import com.google.common.base.Predicate; import org.apache.cassandra.config.ParameterizedClass; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.io.util.Memory; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.utils.Pair; @@ -46,6 +48,7 @@ public class InstanceClassLoader extends URLClassLoader name.startsWith("org.apache.cassandra.distributed.api.") || name.startsWith("sun.") || name.startsWith("oracle.") + || name.startsWith("com.intellij.") || name.startsWith("com.sun.") || name.startsWith("com.oracle.") || name.startsWith("java.") @@ -61,16 +64,16 @@ public static interface Factory InstanceClassLoader create(int id, URL[] urls, ClassLoader sharedClassLoader); } - private final int id; private final URL[] urls; + private final int generation; // used to help debug class loader leaks, by helping determine which classloaders should have been collected private final ClassLoader sharedClassLoader; - InstanceClassLoader(int id, URL[] urls, ClassLoader sharedClassLoader) + InstanceClassLoader(int generation, URL[] urls, ClassLoader sharedClassLoader) { super(urls, null); - this.id = id; this.urls = urls; this.sharedClassLoader = sharedClassLoader; + this.generation = generation; } @Override @@ -107,7 +110,7 @@ public static boolean wasLoadedByAnInstanceClassLoader(Class clazz) public String toString() { return "InstanceClassLoader{" + - "id=" + id + + "generation=" + generation + ", urls=" + Arrays.toString(urls) + '}'; } diff --git 
a/test/distributed/org/apache/cassandra/distributed/impl/InstanceConfig.java b/test/distributed/org/apache/cassandra/distributed/impl/InstanceConfig.java index 6361995d75ae..8c8a774b018e 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/InstanceConfig.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/InstanceConfig.java @@ -20,6 +20,7 @@ import org.apache.cassandra.config.Config; import org.apache.cassandra.config.ParameterizedClass; +import org.apache.cassandra.distributed.api.Feature; import org.apache.cassandra.distributed.api.IInstanceConfig; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.SimpleSeedProvider; @@ -30,6 +31,7 @@ import java.net.UnknownHostException; import java.util.Arrays; import java.util.Collections; +import java.util.EnumSet; import java.util.Map; import java.util.TreeMap; import java.util.UUID; @@ -45,6 +47,8 @@ public class InstanceConfig implements IInstanceConfig public UUID hostId() { return hostId; } private final Map params = new TreeMap<>(); + private EnumSet featureFlags; + private volatile InetAddressAndPort broadcastAddressAndPort; @Override @@ -97,15 +101,15 @@ private InstanceConfig(int num, .set("concurrent_compactors", 1) .set("memtable_heap_space_in_mb", 10) .set("commitlog_sync", "batch") - .set("storage_port", 7010) + .set("storage_port", 7012) .set("endpoint_snitch", SimpleSnitch.class.getName()) .set("seed_provider", new ParameterizedClass(SimpleSeedProvider.class.getName(), - Collections.singletonMap("seeds", "127.0.0.1:7010"))) + Collections.singletonMap("seeds", "127.0.0.1:7012"))) // required settings for dtest functionality .set("diagnostic_events_enabled", true) // legacy parameters .forceSet("commitlog_sync_batch_window_in_ms", 1.0); - // + this.featureFlags = EnumSet.noneOf(Feature.class); } private InstanceConfig(InstanceConfig copy) @@ -113,6 +117,18 @@ private InstanceConfig(InstanceConfig copy) this.num = copy.num; 
this.params.putAll(copy.params); this.hostId = copy.hostId; + this.featureFlags = copy.featureFlags; + } + + public InstanceConfig with(Feature featureFlag) + { + featureFlags.add(featureFlag); + return this; + } + + public boolean has(Feature featureFlag) + { + return featureFlags.contains(featureFlag); } public InstanceConfig set(String fieldName, Object value) @@ -203,13 +219,14 @@ public String getString(String name) return (String)params.get(name); } - public static InstanceConfig generate(int nodeNum, File root, String token) + public static InstanceConfig generate(int nodeNum, int subnet, File root, String token) { + String ipPrefix = "127.0." + subnet + "."; return new InstanceConfig(nodeNum, - "127.0.0." + nodeNum, - "127.0.0." + nodeNum, - "127.0.0." + nodeNum, - "127.0.0." + nodeNum, + ipPrefix + nodeNum, + ipPrefix + nodeNum, + ipPrefix + nodeNum, + ipPrefix + nodeNum, String.format("%s/node%d/saved_caches", root, nodeNum), new String[] { String.format("%s/node%d/data", root, nodeNum) }, String.format("%s/node%d/commitlog", root, nodeNum), diff --git a/test/distributed/org/apache/cassandra/distributed/impl/IsolatedExecutor.java b/test/distributed/org/apache/cassandra/distributed/impl/IsolatedExecutor.java index d82c9e49ab28..1d26c5dec36f 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/IsolatedExecutor.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/IsolatedExecutor.java @@ -27,28 +27,36 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URLClassLoader; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import 
java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; +import org.slf4j.LoggerFactory; + +import ch.qos.logback.classic.LoggerContext; import org.apache.cassandra.concurrent.NamedThreadFactory; import org.apache.cassandra.distributed.api.IIsolatedExecutor; +import org.apache.cassandra.utils.ExecutorUtils; public class IsolatedExecutor implements IIsolatedExecutor { final ExecutorService isolatedExecutor; + private final String name; private final ClassLoader classLoader; private final Method deserializeOnInstance; IsolatedExecutor(String name, ClassLoader classLoader) { + this.name = name; this.isolatedExecutor = Executors.newCachedThreadPool(new NamedThreadFactory("isolatedExecutor", Thread.NORM_PRIORITY, classLoader, new ThreadGroup(name))); this.classLoader = classLoader; this.deserializeOnInstance = lookupDeserializeOneObject(classLoader); @@ -57,9 +65,40 @@ public class IsolatedExecutor implements IIsolatedExecutor public Future shutdown() { isolatedExecutor.shutdown(); - ThrowingRunnable.toRunnable(((URLClassLoader) classLoader)::close).run(); - return CompletableFuture.runAsync(ThrowingRunnable.toRunnable(() -> isolatedExecutor.awaitTermination(60, TimeUnit.SECONDS)), - Executors.newSingleThreadExecutor()); + + /* Use a thread pool with a core pool size of zero to terminate the thread as soon as possible + ** so the instance class loader can be garbage collected. Uses a custom thread factory + ** rather than NamedThreadFactory to avoid calling FastThreadLocal.removeAll() in 3.0 and up + ** as it was observed crashing during test failures and made it harder to find the real cause. 
+ */ + ThreadFactory threadFactory = (Runnable r) -> { + Thread t = new Thread(r, name + "_shutdown"); + t.setDaemon(true); + return t; + }; + ExecutorService shutdownExecutor = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 0, TimeUnit.SECONDS, + new LinkedBlockingQueue(), threadFactory); + return shutdownExecutor.submit(() -> { + try + { + ExecutorUtils.awaitTermination(60, TimeUnit.SECONDS, isolatedExecutor); + + // Shutdown logging last - this is not ideal as the logging subsystem is initialized + // outsize of this class, however doing it this way provides access to the full + // logging system while termination is taking place. + LoggerContext loggerContext = (LoggerContext) LoggerFactory.getILoggerFactory(); + loggerContext.stop(); + + // Close the instance class loader after shutting down the isolatedExecutor and logging + // in case error handling triggers loading additional classes + ((URLClassLoader) classLoader).close(); + } + finally + { + shutdownExecutor.shutdownNow(); + } + return null; + }); } public CallableNoExcept> async(CallableNoExcept call) { return () -> isolatedExecutor.submit(call); } diff --git a/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java b/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java index 7f0f0fcf8d53..8ae12430a14a 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java @@ -21,44 +21,27 @@ import java.util.Arrays; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; -import java.util.function.BiConsumer; import org.apache.cassandra.distributed.api.IInstance; -import org.apache.cassandra.distributed.api.IMessage; import org.apache.cassandra.distributed.api.IMessageFilters; -import org.apache.cassandra.distributed.api.ICluster; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; +import 
org.apache.cassandra.net.Verb; public class MessageFilters implements IMessageFilters { - private final ICluster cluster; private final Set filters = new CopyOnWriteArraySet<>(); - public MessageFilters(AbstractCluster cluster) + public boolean permit(IInstance from, IInstance to, int verb) { - this.cluster = cluster; - } + if (from == null || to == null) + return false; // cannot deliver + int fromNum = from.config().num(); + int toNum = to.config().num(); - public BiConsumer filter(BiConsumer applyIfNotFiltered) - { - return (toAddress, message) -> - { - IInstance from = cluster.get(message.from()); - IInstance to = cluster.get(toAddress); - if (from == null || to == null) - return; // cannot deliver - int fromNum = from.config().num(); - int toNum = to.config().num(); - int verb = message.verb(); - for (Filter filter : filters) - { - if (filter.matches(fromNum, toNum, verb)) - return; - } + for (Filter filter : filters) + if (filter.matches(fromNum, toNum, verb)) + return false; - applyIfNotFiltered.accept(toAddress, message); - }; + return true; } public class Filter implements IMessageFilters.Filter @@ -162,12 +145,11 @@ public Filter drop() } } - @Override - public Builder verbs(MessagingService.Verb... verbs) + public Builder verbs(Verb... 
verbs) { int[] ids = new int[verbs.length]; for (int i = 0 ; i < verbs.length ; ++i) - ids[i] = verbs[i].getId(); + ids[i] = verbs[i].id; return new Builder(ids); } diff --git a/test/distributed/org/apache/cassandra/distributed/impl/Message.java b/test/distributed/org/apache/cassandra/distributed/impl/MessageImpl.java similarity index 79% rename from test/distributed/org/apache/cassandra/distributed/impl/Message.java rename to test/distributed/org/apache/cassandra/distributed/impl/MessageImpl.java index 6f8085c6e7df..e5c72c5438e7 100644 --- a/test/distributed/org/apache/cassandra/distributed/impl/Message.java +++ b/test/distributed/org/apache/cassandra/distributed/impl/MessageImpl.java @@ -22,15 +22,15 @@ import org.apache.cassandra.locator.InetAddressAndPort; // a container for simplifying the method signature for per-instance message handling/delivery -public class Message implements IMessage +public class MessageImpl implements IMessage { - private final int verb; - private final byte[] bytes; - private final int id; - private final int version; - private final InetAddressAndPort from; + public final int verb; + public final byte[] bytes; + public final long id; + public final int version; + public final InetAddressAndPort from; - public Message(int verb, byte[] bytes, int id, int version, InetAddressAndPort from) + public MessageImpl(int verb, byte[] bytes, long id, int version, InetAddressAndPort from) { this.verb = verb; this.bytes = bytes; @@ -39,31 +39,26 @@ public Message(int verb, byte[] bytes, int id, int version, InetAddressAndPort f this.from = from; } - @Override public int verb() { return verb; } - @Override public byte[] bytes() { return bytes; } - @Override public int id() { - return id; + return (int) id; } - @Override public int version() { return version; } - @Override public InetAddressAndPort from() { return from; diff --git a/test/distributed/org/apache/cassandra/distributed/test/DistributedReadWritePathTest.java 
b/test/distributed/org/apache/cassandra/distributed/test/DistributedReadWritePathTest.java index b95f166bbedc..56fd05cb4d54 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/DistributedReadWritePathTest.java +++ b/test/distributed/org/apache/cassandra/distributed/test/DistributedReadWritePathTest.java @@ -26,9 +26,11 @@ import org.apache.cassandra.distributed.Cluster; import org.apache.cassandra.distributed.impl.IInvokableInstance; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; import static org.junit.Assert.assertEquals; -import static org.apache.cassandra.net.MessagingService.Verb.READ_REPAIR; +import static org.apache.cassandra.net.Verb.READ_REPAIR_REQ; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; public class DistributedReadWritePathTest extends DistributedTestBase { @@ -53,6 +55,26 @@ public void coordinatorReadTest() throws Throwable } } + @Test + public void largeMessageTest() throws Throwable + { + try (Cluster cluster = init(Cluster.create(2))) + { + cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v text, PRIMARY KEY (pk, ck))"); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < LARGE_MESSAGE_THRESHOLD ; i++) + builder.append('a'); + String s = builder.toString(); + cluster.coordinator(1).execute("INSERT INTO " + KEYSPACE + ".tbl (pk, ck, v) VALUES (1, 1, ?)", + ConsistencyLevel.ALL, + s); + assertRows(cluster.coordinator(1).execute("SELECT * FROM " + KEYSPACE + ".tbl WHERE pk = ?", + ConsistencyLevel.ALL, + 1), + row(1, 1, s)); + } + } + @Test public void coordinatorWriteTest() throws Throwable { @@ -109,7 +131,7 @@ public void failingReadRepairTest() throws Throwable assertRows(cluster.get(3).executeInternal("SELECT * FROM " + KEYSPACE + ".tbl WHERE pk = 1")); - cluster.verbs(READ_REPAIR).to(3).drop(); + cluster.verbs(READ_REPAIR_REQ).to(3).drop(); assertRows(cluster.coordinator(1).execute("SELECT * FROM " + KEYSPACE + ".tbl WHERE 
pk = 1", ConsistencyLevel.QUORUM), row(1, 1, 1)); @@ -122,7 +144,7 @@ public void failingReadRepairTest() throws Throwable @Test public void writeWithSchemaDisagreement() throws Throwable { - try (Cluster cluster = init(Cluster.create(3))) + try (Cluster cluster = init(Cluster.build(3).withConfig(config -> config.with(NETWORK)).start())) { cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v1 int, PRIMARY KEY (pk, ck))"); @@ -144,15 +166,15 @@ public void writeWithSchemaDisagreement() throws Throwable thrown = e; } - Assert.assertTrue(thrown.getMessage().contains("Exception occurred on node")); - Assert.assertTrue(thrown.getCause().getCause().getCause().getMessage().contains("Unknown column v2 during deserialization")); + Assert.assertTrue(thrown.getMessage().contains("INCOMPATIBLE_SCHEMA from 127.0.0.2")); + Assert.assertTrue(thrown.getMessage().contains("INCOMPATIBLE_SCHEMA from 127.0.0.3")); } } @Test public void readWithSchemaDisagreement() throws Throwable { - try (Cluster cluster = init(Cluster.create(3))) + try (Cluster cluster = init(Cluster.create(3, config -> config.with(NETWORK)))) { cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v1 int, PRIMARY KEY (pk, ck))"); @@ -174,8 +196,9 @@ public void readWithSchemaDisagreement() throws Throwable { thrown = e; } - Assert.assertTrue(thrown.getMessage().contains("Exception occurred on node")); - Assert.assertTrue(thrown.getCause().getCause().getCause().getMessage().contains("Unknown column v2 during deserialization")); + + Assert.assertTrue(thrown.getMessage().contains("INCOMPATIBLE_SCHEMA from 127.0.0.2")); + Assert.assertTrue(thrown.getMessage().contains("INCOMPATIBLE_SCHEMA from 127.0.0.3")); } } @@ -243,7 +266,7 @@ public void pagingWithRepairTest() throws Throwable public void pagingTests() throws Throwable { try (Cluster cluster = init(Cluster.create(3)); - Cluster singleNode = init(Cluster.create(1))) + Cluster singleNode = 
init(Cluster.build(1).withSubnet(1).start())) { cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); singleNode.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); diff --git a/test/distributed/org/apache/cassandra/distributed/test/DistributedTestBase.java b/test/distributed/org/apache/cassandra/distributed/test/DistributedTestBase.java index 9e4579dcba47..35d36917b970 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/DistributedTestBase.java +++ b/test/distributed/org/apache/cassandra/distributed/test/DistributedTestBase.java @@ -29,6 +29,7 @@ import org.junit.BeforeClass; import org.apache.cassandra.distributed.impl.AbstractCluster; +import org.apache.cassandra.distributed.impl.IsolatedExecutor; public class DistributedTestBase { @@ -41,10 +42,34 @@ public void afterEach() public static String KEYSPACE = "distributed_test_keyspace"; + public static void nativeLibraryWorkaround() + { + // Disable the Netty tcnative library otherwise the io.netty.internal.tcnative.CertificateCallbackTask, + // CertificateVerifierTask, SSLPrivateKeyMethodDecryptTask, SSLPrivateKeyMethodSignTask, + // SSLPrivateKeyMethodTask, and SSLTask hold a gcroot against the InstanceClassLoader. + System.setProperty("cassandra.disable_tcactive_openssl", "true"); + System.setProperty("io.netty.transport.noNative", "true"); + } + + public static void processReaperWorkaround() + { + // Make sure the 'process reaper' thread is initially created under the main classloader, + // otherwise it gets created with the contextClassLoader pointing to an InstanceClassLoader + // which prevents it from being garbage collected. 
+ IsolatedExecutor.ThrowingRunnable.toRunnable(() -> new ProcessBuilder().command("true").start().waitFor()).run(); + } + @BeforeClass public static void setup() { System.setProperty("org.apache.cassandra.disable_mbean_registration", "true"); + nativeLibraryWorkaround(); + processReaperWorkaround(); + } + + static String withKeyspace(String replaceIn) + { + return String.format(replaceIn, KEYSPACE); } protected static > C init(C cluster) @@ -62,7 +87,7 @@ public static void assertRows(Object[][] actual, Object[]... expected) { Object[] expectedRow = expected[i]; Object[] actualRow = actual[i]; - Assert.assertTrue(rowsNotEqualErrorMessage(actual, expected), + Assert.assertTrue(rowsNotEqualErrorMessage(expected, actual), Arrays.equals(expectedRow, actualRow)); } } diff --git a/test/distributed/org/apache/cassandra/distributed/test/GossipSettlesTest.java b/test/distributed/org/apache/cassandra/distributed/test/GossipSettlesTest.java new file mode 100644 index 000000000000..d5542628780a --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/GossipSettlesTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test; + +import java.io.IOException; + +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; + +public class GossipSettlesTest extends DistributedTestBase +{ + + @Test + public void test() + { + try (Cluster cluster = Cluster.create(3, config -> config.with(GOSSIP).with(NETWORK))) + { + } + catch (IOException e) + { + e.printStackTrace(); + } + } + +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/LargeColumnTest.java b/test/distributed/org/apache/cassandra/distributed/test/LargeColumnTest.java new file mode 100644 index 000000000000..0c28d38ffde3 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/LargeColumnTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test; + +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.db.ConsistencyLevel; +import org.apache.cassandra.distributed.Cluster; + +import static java.util.concurrent.TimeUnit.SECONDS; + +public class LargeColumnTest extends DistributedTestBase +{ + private static final Logger logger = LoggerFactory.getLogger(LargeColumnTest.class); + private static String str(int length, Random random, long seed) + { + random.setSeed(seed); + char[] chars = new char[length]; + int i = 0; + int s = 0; + long v = 0; + while (i < length) + { + if (s == 0) + { + v = random.nextLong(); + s = 8; + } + chars[i] = (char) (((v & 127) + 32) & 127); + v >>= 8; + --s; + ++i; + } + return new String(chars); + } + + private void testLargeColumns(int nodes, int columnSize, int rowCount) throws Throwable + { + Random random = new Random(); + long seed = ThreadLocalRandom.current().nextLong(); + logger.info("Using seed {}", seed); + + try (Cluster cluster = init(Cluster.build(nodes) + .withConfig(config -> + config.set("commitlog_segment_size_in_mb", (columnSize * 3) >> 20) + .set("internode_application_send_queue_reserve_endpoint_capacity_in_bytes", columnSize * 2) + .set("internode_application_send_queue_reserve_global_capacity_in_bytes", columnSize * 3) + .set("write_request_timeout_in_ms", SECONDS.toMillis(30L)) + .set("read_request_timeout_in_ms", SECONDS.toMillis(30L)) + .set("memtable_heap_space_in_mb", 1024) + ) + .start())) + { + cluster.schemaChange(String.format("CREATE TABLE %s.cf (k int, c text, PRIMARY KEY (k))", KEYSPACE)); + + for (int i = 0 ; i < rowCount ; ++i) + cluster.coordinator(1).execute(String.format("INSERT INTO %s.cf (k, c) VALUES (?, ?);", KEYSPACE), ConsistencyLevel.ALL, i, str(columnSize, random, seed | i)); + + for 
(int i = 0 ; i < rowCount ; ++i) + { + Object[][] results = cluster.coordinator(1).execute(String.format("SELECT k, c FROM %s.cf WHERE k = ?;", KEYSPACE), ConsistencyLevel.ALL, i); + Assert.assertTrue(str(columnSize, random, seed | i).equals(results[0][1])); + } + } + } + + @Test + public void test() throws Throwable + { + testLargeColumns(2, 16 << 20, 5); + } + +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/RepairDigestTrackingTest.java b/test/distributed/org/apache/cassandra/distributed/test/RepairDigestTrackingTest.java index a987ea30cd81..0c394550f775 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/RepairDigestTrackingTest.java +++ b/test/distributed/org/apache/cassandra/distributed/test/RepairDigestTrackingTest.java @@ -22,7 +22,6 @@ import java.io.Serializable; import java.util.EnumSet; import java.util.Map; -import java.util.Set; import org.junit.Assert; import org.junit.Test; @@ -60,10 +59,10 @@ public void testInconsistenciesFound() throws Throwable } cluster.get(1).runOnInstance(() -> - Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlush() + Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlushToSSTable() ); cluster.get(2).runOnInstance(() -> - Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlush() + Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlushToSSTable() ); for (int i = 10; i < 20; i++) @@ -74,10 +73,10 @@ public void testInconsistenciesFound() throws Throwable } cluster.get(1).runOnInstance(() -> - Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlush() + Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlushToSSTable() ); cluster.get(2).runOnInstance(() -> - Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlush() + Keyspace.open(KEYSPACE).getColumnFamilyStore("tbl").forceBlockingFlushToSSTable() ); cluster.get(1).runOnInstance(() -> diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/RepairTest.java b/test/distributed/org/apache/cassandra/distributed/test/RepairTest.java new file mode 100644 index 000000000000..1c488aaf5cfa --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/RepairTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.impl.InstanceConfig; +import org.apache.cassandra.service.StorageService; +import org.apache.cassandra.utils.concurrent.SimpleCondition; +import org.apache.cassandra.utils.progress.ProgressEventType; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.apache.cassandra.distributed.impl.ExecUtil.rethrow; +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; + +public class RepairTest extends DistributedTestBase +{ + private static final String insert = withKeyspace("INSERT INTO %s.test (k, c1, c2) VALUES (?, 'value1', 'value2');"); + private static final String query = withKeyspace("SELECT k, c1, c2 FROM %s.test WHERE k = ?;"); + private static Cluster cluster; + + private static void insert(Cluster cluster, int start, int end, int ... nodes) + { + for (int i = start ; i < end ; ++i) + for (int node : nodes) + cluster.get(node).executeInternal(insert, Integer.toString(i)); + } + + private static void verify(Cluster cluster, int start, int end, int ... nodes) + { + for (int i = start ; i < end ; ++i) + { + for (int node = 1 ; node <= cluster.size() ; ++node) + { + Object[][] rows = cluster.get(node).executeInternal(query, Integer.toString(i)); + if (Arrays.binarySearch(nodes, node) >= 0) + assertRows(rows, new Object[] { Integer.toString(i), "value1", "value2" }); + else + assertRows(rows); + } + } + } + + private static void flush(Cluster cluster, int ... 
nodes) + { + for (int node : nodes) + cluster.get(node).runOnInstance(rethrow(() -> StorageService.instance.forceKeyspaceFlush(KEYSPACE))); + } + + private static Cluster create(Consumer configModifier) throws IOException + { + configModifier = configModifier.andThen( + config -> config.set("hinted_handoff_enabled", false) + .set("commitlog_sync_batch_window_in_ms", 5) + .with(NETWORK) + .with(GOSSIP) + ); + + Cluster cluster = init(Cluster.build(3).withConfig(configModifier).start()); + return cluster; + } + + private void repair(Cluster cluster, Map options) + { + cluster.get(1).runOnInstance(rethrow(() -> { + SimpleCondition await = new SimpleCondition(); + StorageService.instance.repair(KEYSPACE, options, ImmutableList.of((tag, event) -> { + if (event.getType() == ProgressEventType.COMPLETE) + await.signalAll(); + })).right.get(); + await.await(1L, MINUTES); + })); + } + + void populate(Cluster cluster, boolean compression) + { + try + { + cluster.schemaChange(withKeyspace("DROP TABLE IF EXISTS %s.test;")); + cluster.schemaChange(withKeyspace("CREATE TABLE %s.test (k text, c1 text, c2 text, PRIMARY KEY (k))") + + (compression == false ? " WITH compression = {'enabled' : false};" : ";")); + + insert(cluster, 0, 1000, 1, 2, 3); + flush(cluster, 1); + insert(cluster, 1000, 1001, 1, 2); + insert(cluster, 1001, 2001, 1, 2, 3); + flush(cluster, 1, 2, 3); + + verify(cluster, 0, 1000, 1, 2, 3); + verify(cluster, 1000, 1001, 1, 2); + verify(cluster, 1001, 2001, 1, 2, 3); + } + catch (Throwable t) + { + cluster.close(); + throw t; + } + + } + + void simpleRepair(Cluster cluster, boolean sequential, boolean compression) throws IOException + { + populate(cluster, compression); + repair(cluster, ImmutableMap.of("parallelism", sequential ? 
"sequential" : "parallel")); + verify(cluster, 0, 2001, 1, 2, 3); + } + + @BeforeClass + public static void setupCluster() throws IOException + { + cluster = create(config -> {}); + } + + @Ignore("Test requires CASSANDRA-13938 to be merged") + public void testSimpleSequentialRepairDefaultCompression() throws IOException + { + simpleRepair(cluster, true, true); + } + + @Test + public void testSimpleSequentialRepairCompressionOff() throws IOException + { + simpleRepair(cluster, true, false); + } +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/ResourceLeakTest.java b/test/distributed/org/apache/cassandra/distributed/test/ResourceLeakTest.java new file mode 100644 index 000000000000..4bfbdc9c183f --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/ResourceLeakTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test; + +import java.io.File; +import java.io.IOException; +import java.lang.management.ManagementFactory; +import java.nio.file.FileSystems; +import java.nio.file.Path; +import java.sql.Date; +import java.text.SimpleDateFormat; +import java.time.Instant; +import java.util.List; +import java.util.function.Consumer; +import javax.management.MBeanServer; + +import org.junit.Ignore; +import org.junit.Test; + +import com.sun.management.HotSpotDiagnosticMXBean; +import org.apache.cassandra.db.ConsistencyLevel; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.impl.InstanceConfig; +import org.apache.cassandra.gms.Gossiper; +import org.apache.cassandra.service.CassandraDaemon; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.SigarLibrary; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; + +/* Resource Leak Test - useful when tracking down issues with in-JVM framework cleanup. + * All objects referencing the InstanceClassLoader need to be garbage collected or + * the JVM runs out of metaspace. This test also calls out to lsof to check which + * file handles are still opened. + * + * This is intended to be a burn type test where it is run outside of the test suites + * when a problem is detected (like OutOfMetaspace exceptions). + * + * Currently this test demonstrates that the InstanceClassLoader is cleaned up (load up + * the final hprof and check that the class loaders are not reachable from a GC root), + * but it shows that the file handles for Data/Index files are being leaked. + */ +@Ignore +public class ResourceLeakTest extends DistributedTestBase +{ + // Parameters to adjust while hunting for leaks + final int numTestLoops = 1; // Set this value high to crash on leaks, or low when tracking down an issue. 
+ final boolean dumpEveryLoop = false; // Dump heap & possibly files every loop + final boolean dumpFileHandles = false; // Call lsof whenever dumping resources + final boolean forceCollection = false; // Whether to explicitly force finalization/gc for smaller heap dumps + final long finalWaitMillis = 0l; // Number of millis to wait before final resource dump to give gc a chance + + static final SimpleDateFormat format = new SimpleDateFormat("yyyyMMddHHmmss"); + static final String when = format.format(Date.from(Instant.now())); + + static String outputFilename(String base, String description, String extension) + { + Path p = FileSystems.getDefault().getPath("build", "test", + String.join("-", when, base, description) + extension); + return p.toString(); + } + + /** + * Retrieves the process ID or null if the process ID cannot be retrieved. + * @return the process ID or null if the process ID cannot be retrieved. + * + * (Duplicated from HeapUtils to avoid refactoring older releases where this test is useful). + */ + private static Long getProcessId() + { + // Once Java 9 is ready the process API should provide a better way to get the process ID. + long pid = SigarLibrary.instance.getPid(); + + if (pid >= 0) + return Long.valueOf(pid); + + return getProcessIdFromJvmName(); + } + + /** + * Retrieves the process ID from the JVM name. + * @return the process ID or null if the process ID cannot be retrieved. 
+ */ + private static Long getProcessIdFromJvmName() + { + // the JVM name in Oracle JVMs is: '@' but this might not be the case on all JVMs + String jvmName = ManagementFactory.getRuntimeMXBean().getName(); + try + { + return Long.parseLong(jvmName.split("@")[0]); + } + catch (NumberFormatException e) + { + // ignore + } + return null; + } + + static void dumpHeap(String description, boolean live) throws IOException + { + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + HotSpotDiagnosticMXBean mxBean = ManagementFactory.newPlatformMXBeanProxy( + server, "com.sun.management:type=HotSpotDiagnostic", HotSpotDiagnosticMXBean.class); + mxBean.dumpHeap(outputFilename("heap", description, ".hprof"), live); + } + + static void dumpOpenFiles(String description) throws IOException, InterruptedException + { + long pid = getProcessId(); + ProcessBuilder map = new ProcessBuilder("/usr/sbin/lsof", "-p", Long.toString(pid)); + File output = new File(outputFilename("lsof", description, ".txt")); + map.redirectOutput(output); + map.redirectErrorStream(true); + map.start().waitFor(); + } + + void dumpResources(String description) throws IOException, InterruptedException + { + dumpHeap(description, false); + if (dumpFileHandles) + { + dumpOpenFiles(description); + } + } + + void doTest(int numClusterNodes, Consumer updater) throws Throwable + { + for (int loop = 0; loop < numTestLoops; loop++) + { + try (Cluster cluster = Cluster.build(numClusterNodes).withConfig(updater).start()) + { + if (cluster.get(1).config().has(GOSSIP)) // Wait for gossip to settle on the seed node + cluster.get(1).runOnInstance(() -> Gossiper.waitToSettle()); + + init(cluster); + String tableName = "tbl" + loop; + cluster.schemaChange("CREATE TABLE " + KEYSPACE + "." + tableName + " (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); + cluster.coordinator(1).execute("INSERT INTO " + KEYSPACE + "." 
+ tableName + "(pk,ck,v) VALUES (0,0,0)", ConsistencyLevel.ALL); + cluster.get(1).callOnInstance(() -> FBUtilities.waitOnFutures(Keyspace.open(KEYSPACE).flush())); + if (dumpEveryLoop) + { + dumpResources(String.format("loop%03d", loop)); + } + } + catch (Throwable tr) + { + System.out.println("Dumping resources for exception: " + tr.getMessage()); + tr.printStackTrace(); + dumpResources("exception"); + } + if (forceCollection) + { + System.runFinalization(); + System.gc(); + } + } + } + + @Test + public void looperTest() throws Throwable + { + doTest(1, config -> {}); + if (forceCollection) + { + System.runFinalization(); + System.gc(); + Thread.sleep(finalWaitMillis); + } + dumpResources("final"); + } + + @Test + public void looperGossipNetworkTest() throws Throwable + { + doTest(2, config -> config.with(GOSSIP).with(NETWORK)); + if (forceCollection) + { + System.runFinalization(); + System.gc(); + Thread.sleep(finalWaitMillis); + } + dumpResources("final-gossip-network"); + } +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/StreamingTest.java b/test/distributed/org/apache/cassandra/distributed/test/StreamingTest.java new file mode 100644 index 000000000000..22cd5907c7c5 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/StreamingTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.distributed.test; + +import java.util.Arrays; +import java.util.Comparator; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.service.StorageService; + +import static org.apache.cassandra.distributed.api.Feature.NETWORK; + +public class StreamingTest extends DistributedTestBase +{ + + private void testStreaming(int nodes, int replicationFactor, int rowCount, String compactionStrategy) throws Throwable + { + try (Cluster cluster = Cluster.create(nodes, config -> config.with(NETWORK))) + { + cluster.schemaChange("CREATE KEYSPACE " + KEYSPACE + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': " + replicationFactor + "};"); + cluster.schemaChange(String.format("CREATE TABLE %s.cf (k text, c1 text, c2 text, PRIMARY KEY (k)) WITH compaction = {'class': '%s', 'enabled': 'true'}", KEYSPACE, compactionStrategy)); + + for (int i = 0 ; i < rowCount ; ++i) + { + for (int n = 1 ; n < nodes ; ++n) + cluster.get(n).executeInternal(String.format("INSERT INTO %s.cf (k, c1, c2) VALUES (?, 'value1', 'value2');", KEYSPACE), Integer.toString(i)); + } + + cluster.get(nodes).executeInternal("TRUNCATE system.available_ranges;"); + { + Object[][] results = cluster.get(nodes).executeInternal(String.format("SELECT k, c1, c2 FROM %s.cf;", KEYSPACE)); + Assert.assertEquals(0, results.length); + } + + cluster.get(nodes).runOnInstance(() -> StorageService.instance.rebuild(null, KEYSPACE, null, null)); + { + Object[][] results = 
cluster.get(nodes).executeInternal(String.format("SELECT k, c1, c2 FROM %s.cf;", KEYSPACE)); + Assert.assertEquals(1000, results.length); + Arrays.sort(results, Comparator.comparingInt(a -> Integer.parseInt((String) a[0]))); + for (int i = 0 ; i < results.length ; ++i) + { + Assert.assertEquals(Integer.toString(i), results[i][0]); + Assert.assertEquals("value1", results[i][1]); + Assert.assertEquals("value2", results[i][2]); + } + } + } + } + + @Test + public void test() throws Throwable + { + testStreaming(2, 2, 1000, "LeveledCompactionStrategy"); + } + +} diff --git a/test/distributed/org/apache/cassandra/distributed/util/PyDtest.java b/test/distributed/org/apache/cassandra/distributed/util/PyDtest.java new file mode 100644 index 000000000000..3b2425f74d8c --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/util/PyDtest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.util; + +import java.util.Arrays; +import java.util.List; + +public class PyDtest +{ + + public static class CreateCf + { + final String keyspace; + final String name; + String primaryKey, clustering, keyType, speculativeRetry, compression, validation, compactionStrategy; + Float readRepair; + Integer gcGrace; + List columns; + Boolean compactStorage; + + public CreateCf(String keyspace, String name) + { + this.keyspace = keyspace; + this.name = name; + } + + public CreateCf withPrimaryKey(String primaryKey) + { + this.primaryKey = primaryKey; + return this; + } + + public CreateCf withClustering(String clustering) + { + this.clustering = clustering; + return this; + } + + public CreateCf withKeyType(String keyType) + { + this.keyType = keyType; + return this; + } + + public CreateCf withSpeculativeRetry(String speculativeRetry) + { + this.speculativeRetry = speculativeRetry; + return this; + } + + public CreateCf withCompression(String compression) + { + this.compression = compression; + return this; + } + + public CreateCf withValidation(String validation) + { + this.validation = validation; + return this; + } + + public CreateCf withCompactionStrategy(String compactionStrategy) + { + this.compactionStrategy = compactionStrategy; + return this; + } + + public CreateCf withReadRepair(Float readRepair) + { + this.readRepair = readRepair; + return this; + } + + public CreateCf withGcGrace(Integer gcGrace) + { + this.gcGrace = gcGrace; + return this; + } + + public CreateCf withColumns(List columns) + { + this.columns = columns; + return this; + } + + public CreateCf withColumns(String ... 
columns) + { + this.columns = Arrays.asList(columns); + return this; + } + + public CreateCf withCompactStorage(Boolean compactStorage) + { + this.compactStorage = compactStorage; + return this; + } + + public String build() + { + if (keyspace == null) + throw new IllegalArgumentException(); + if (name == null) + throw new IllegalArgumentException(); + if (keyType == null) + keyType = "varchar"; + if (validation == null) + validation = "UTF8Type"; + if (compactionStrategy == null) + compactionStrategy = "SizeTieredCompactionStrategy"; + if (compactStorage == null) + compactStorage = false; + + + String compaction_fragment = String.format("compaction = {'class': '%s', 'enabled': 'true'}", compactionStrategy); + + String query; + String additional_columns = ""; + if (columns == null) + { + query = String.format("CREATE COLUMNFAMILY %s.%s (key %s, c varchar, v varchar, PRIMARY KEY(key, c)) WITH comment=\'test cf\'", keyspace, name, keyType); + } + else + { + for (String pair : columns) + { + String[] split = pair.split(":"); + String key = split[0]; + String type = split[1]; + additional_columns += ", " + key + " " + type; + } + + if (primaryKey != null) + query = String.format("CREATE COLUMNFAMILY %s.%s (key %s%s, PRIMARY KEY(%s)) WITH comment=\'test cf\'", keyspace, name, keyType, additional_columns, primaryKey); + else + query = String.format("CREATE COLUMNFAMILY %s.%s (key %s PRIMARY KEY%s) WITH comment=\'test cf\'", keyspace, name, keyType, additional_columns); + } + + + if (compaction_fragment != null) + query += " AND " + compaction_fragment; + + if (clustering != null) + query += String.format(" AND CLUSTERING ORDER BY (%s)", clustering); + + if (compression != null) + query += String.format(" AND compression = { \'sstable_compression\': \'%sCompressor\' }", compression); + else + query += " AND compression = {}"; + + if (readRepair != null) + query += String.format(" AND read_repair_chance=%f AND dclocal_read_repair_chance=%f", readRepair, readRepair); + if 
(gcGrace != null) + query += String.format(" AND gc_grace_seconds=%d", gcGrace); + if (speculativeRetry != null) + query += String.format(" AND speculative_retry=\'%s\'", speculativeRetry); + + if (compactStorage != null && compactStorage) + query += " AND COMPACT STORAGE"; + + return query; + } + } + + public static CreateCf createCf(String keyspace, String name) + { + return new CreateCf(keyspace, name); + } + +} diff --git a/test/long/org/apache/cassandra/db/commitlog/CommitLogStressTest.java b/test/long/org/apache/cassandra/db/commitlog/CommitLogStressTest.java index e2c6e33789bc..8f217439d3c1 100644 --- a/test/long/org/apache/cassandra/db/commitlog/CommitLogStressTest.java +++ b/test/long/org/apache/cassandra/db/commitlog/CommitLogStressTest.java @@ -35,15 +35,12 @@ import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import io.netty.util.concurrent.FastThreadLocalThread; import org.apache.cassandra.SchemaLoader; import org.apache.cassandra.Util; import org.apache.cassandra.UpdateBuilder; -import org.apache.cassandra.config.Config.CommitLogSync; import org.apache.cassandra.config.*; import org.apache.cassandra.db.Mutation; import org.apache.cassandra.db.marshal.UTF8Type; @@ -287,7 +284,7 @@ private void verifySizes(CommitLog commitLog) List logFileNames = commitLog.getActiveSegmentNames(); Map ratios = commitLog.getActiveSegmentCompressionRatios(); - Collection segments = commitLog.segmentManager.getActiveSegments(); + Collection segments = commitLog.segmentManager.getSegmentsForUnflushedTables(); for (CommitLogSegment segment : segments) { diff --git a/test/long/org/apache/cassandra/db/compaction/LongCompactionsTest.java b/test/long/org/apache/cassandra/db/compaction/LongCompactionsTest.java index fe8cdc2fd92b..6e247936a9ae 100644 --- a/test/long/org/apache/cassandra/db/compaction/LongCompactionsTest.java +++ 
b/test/long/org/apache/cassandra/db/compaction/LongCompactionsTest.java @@ -165,7 +165,7 @@ public void testStandardColumnCompactions() inserted.add(key); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); CompactionsTest.assertMaxTimestamp(cfs, maxTimestampExpected); assertEquals(inserted.toString(), inserted.size(), Util.getAll(Util.cmd(cfs).build()).size()); diff --git a/test/long/org/apache/cassandra/db/compaction/LongLeveledCompactionStrategyTest.java b/test/long/org/apache/cassandra/db/compaction/LongLeveledCompactionStrategyTest.java index f8f94a0aadfa..62c4099d1803 100644 --- a/test/long/org/apache/cassandra/db/compaction/LongLeveledCompactionStrategyTest.java +++ b/test/long/org/apache/cassandra/db/compaction/LongLeveledCompactionStrategyTest.java @@ -29,7 +29,6 @@ import org.apache.cassandra.io.sstable.ISSTableScanner; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.apache.cassandra.SchemaLoader; @@ -40,7 +39,6 @@ import org.apache.cassandra.schema.CompactionParams; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.service.ActiveRepairService; -import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; import static org.junit.Assert.assertFalse; @@ -164,7 +162,7 @@ public void testLeveledScanner() throws Exception } //Flush sstable - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); store.runWithCompactionsDisabled(new Callable() { @@ -263,7 +261,7 @@ private void populateSSTables(ColumnFamilyStore store) Mutation rm = new Mutation(builder.build()); rm.apply(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); } } } diff --git a/test/long/org/apache/cassandra/locator/DynamicEndpointSnitchLongTest.java b/test/long/org/apache/cassandra/locator/DynamicEndpointSnitchLongTest.java index 94a3bd3569f1..2e27738efc18 100644 --- 
a/test/long/org/apache/cassandra/locator/DynamicEndpointSnitchLongTest.java +++ b/test/long/org/apache/cassandra/locator/DynamicEndpointSnitchLongTest.java @@ -30,6 +30,8 @@ import org.apache.cassandra.utils.FBUtilities; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public class DynamicEndpointSnitchLongTest { static @@ -101,7 +103,7 @@ public void run() { Replica host = hosts.get(random.nextInt(hosts.size())); int score = random.nextInt(SCORE_RANGE); - dsnitch.receiveTiming(host.endpoint(), score); + dsnitch.receiveTiming(host.endpoint(), score, MILLISECONDS); } } } diff --git a/test/microbench/org/apache/cassandra/test/microbench/CompactionBench.java b/test/microbench/org/apache/cassandra/test/microbench/CompactionBench.java index 41220a2a655c..4b817305439e 100644 --- a/test/microbench/org/apache/cassandra/test/microbench/CompactionBench.java +++ b/test/microbench/org/apache/cassandra/test/microbench/CompactionBench.java @@ -70,13 +70,13 @@ public void setup() throws Throwable execute(writeStatement, i, i, i ); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); System.err.println("Writing 50k again..."); for (long i = 0; i < 50000; i++) execute(writeStatement, i, i, i ); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.snapshot("originals"); diff --git a/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java index 2aec66883a8b..4ab607f49c30 100644 --- a/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java +++ b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java @@ -19,31 +19,23 @@ package org.apache.cassandra.test.microbench; import java.io.IOException; -import java.util.Collections; import java.util.EnumMap; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; -import com.google.common.collect.ImmutableList; 
import com.google.common.net.InetAddresses; -import com.google.common.primitives.Shorts; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.net.async.BaseMessageInHandler; -import org.apache.cassandra.net.async.ByteBufDataOutputPlus; -import org.apache.cassandra.net.async.MessageInHandler; -import org.apache.cassandra.net.async.MessageInHandlerPre40; -import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; +import org.apache.cassandra.net.NoPayload; +import org.apache.cassandra.net.ParamType; import org.apache.cassandra.utils.UUIDGen; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -57,7 +49,7 @@ import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; +import static org.apache.cassandra.net.Verb.ECHO_REQ; @State(Scope.Thread) @Warmup(iterations = 4, time = 1, timeUnit = TimeUnit.SECONDS) @@ -70,60 +62,51 @@ public class MessageOutBench @Param({ "true", "false" }) private boolean withParams; - private MessageOut msgOut; + private Message msgOut; private ByteBuf buf; - BaseMessageInHandler handler40; - BaseMessageInHandler handlerPre40; + private InetAddressAndPort addr; @Setup public void setup() { DatabaseDescriptor.daemonInitialization(); - InetAddressAndPort addr = 
InetAddressAndPort.getByAddress(InetAddresses.forString("127.0.73.101")); UUID uuid = UUIDGen.getTimeUUID(); - Map parameters = new EnumMap<>(ParameterType.class); + Map parameters = new EnumMap<>(ParamType.class); if (withParams) { - parameters.put(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); - parameters.put(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); - parameters.put(ParameterType.TRACE_SESSION, uuid); + parameters.put(ParamType.TRACE_SESSION, uuid); } - msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); + addr = InetAddressAndPort.getByAddress(InetAddresses.forString("127.0.73.101")); + msgOut = Message.builder(ECHO_REQ, NoPayload.noPayload) + .from(addr) + .build(); buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! - - handler40 = new MessageInHandler(addr, MessagingService.VERSION_40, messageConsumer); - handlerPre40 = new MessageInHandlerPre40(addr, MessagingService.VERSION_30, messageConsumer); } @Benchmark public int serialize40() throws Exception { - return serialize(MessagingService.VERSION_40, handler40); + return serialize(MessagingService.VERSION_40); } - private int serialize(int messagingVersion, BaseMessageInHandler handler) throws Exception + private int serialize(int messagingVersion) throws IOException { - buf.resetReaderIndex(); - buf.resetWriterIndex(); - buf.writeInt(MessagingService.PROTOCOL_MAGIC); - buf.writeInt(42); // this is the id - buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime())); - - msgOut.serialize(new ByteBufDataOutputPlus(buf), messagingVersion); - handler.decode(null, buf, Collections.emptyList()); - return msgOut.serializedSize(messagingVersion); + try (DataOutputBuffer out = new DataOutputBuffer()) + { + Message.serializer.serialize(Message.builder(msgOut).withCreatedAt(System.nanoTime()).withId(42).build(), + out, messagingVersion); + 
DataInputBuffer in = new DataInputBuffer(out.buffer(), false); + Message.serializer.deserialize(in, addr, messagingVersion); + return msgOut.serializedSize(messagingVersion); + } } @Benchmark public int serializePre40() throws Exception { - return serialize(MessagingService.VERSION_30, handlerPre40); + return serialize(MessagingService.VERSION_30); } - - private final BiConsumer messageConsumer = (messageIn, integer) -> - { - }; } diff --git a/test/microbench/org/apache/cassandra/test/microbench/MutationBench.java b/test/microbench/org/apache/cassandra/test/microbench/MutationBench.java index 4a0e64643c80..074e183f2a9b 100644 --- a/test/microbench/org/apache/cassandra/test/microbench/MutationBench.java +++ b/test/microbench/org/apache/cassandra/test/microbench/MutationBench.java @@ -34,8 +34,6 @@ import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.io.util.DataOutputBufferFixed; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.KeyspaceMetadata; import org.apache.cassandra.schema.KeyspaceParams; @@ -73,7 +71,6 @@ public class MutationBench static String keyspace = "keyspace1"; private Mutation mutation; - private MessageOut messageOut; private ByteBuffer buffer; private DataOutputBuffer outputBuffer; @@ -83,7 +80,7 @@ public class MutationBench @State(Scope.Thread) public static class ThreadState { - MessageIn in; + Mutation in; int counter = 0; } @@ -103,19 +100,18 @@ public void setup() throws IOException Schema.instance.load(ksm.withSwapped(ksm.tables.with(metadata))); mutation = (Mutation)UpdateBuilder.create(metadata, 1L).newRow(1L).add("commentid", 32L).makeMutation(); - messageOut = mutation.createMessage(); - buffer = ByteBuffer.allocate(messageOut.serializedSize(MessagingService.current_version)); + buffer = ByteBuffer.allocate((int) 
Mutation.serializer.serializedSize(mutation, MessagingService.current_version)); outputBuffer = new DataOutputBufferFixed(buffer); inputBuffer = new DataInputBuffer(buffer, false); - messageOut.serialize(outputBuffer, MessagingService.current_version); + Mutation.serializer.serialize(mutation, outputBuffer, MessagingService.current_version); } @Benchmark public void serialize(ThreadState state) throws IOException { buffer.rewind(); - messageOut.serialize(outputBuffer, MessagingService.current_version); + Mutation.serializer.serialize(mutation, outputBuffer, MessagingService.current_version); state.counter++; } @@ -123,7 +119,7 @@ public void serialize(ThreadState state) throws IOException public void deserialize(ThreadState state) throws IOException { buffer.rewind(); - state.in = MessageIn.read(inputBuffer, MessagingService.current_version, 0); + state.in = Mutation.serializer.deserialize(inputBuffer, MessagingService.current_version); state.counter++; } diff --git a/test/microbench/org/apache/cassandra/test/microbench/PreaggregatedByteBufsBench.java b/test/microbench/org/apache/cassandra/test/microbench/PreaggregatedByteBufsBench.java new file mode 100644 index 000000000000..9971cc5b4a48 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/PreaggregatedByteBufsBench.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.util.concurrent.TimeUnit; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +@State(Scope.Thread) +@Warmup(iterations = 4, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 8, time = 4, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1,jvmArgsAppend = "-Xmx512M") +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.SampleTime) +public class PreaggregatedByteBufsBench +{ + @Param({ "1000", "2500", "5000", "10000", "20000", "40000"}) + private int len; + + private static final int subBufferCount = 64; + + private EmbeddedChannel channel; + + @Setup + public void setUp() + { + channel = new EmbeddedChannel(); + } + + @Benchmark + public boolean oneBigBuf() + { + boolean success = true; + try + { + ByteBuf buf = channel.alloc().directBuffer(len); + buf.writerIndex(len); + channel.writeAndFlush(buf); + } + catch (Exception e) + { + success = false; + } + finally + { + channel.releaseOutbound(); + } + + return success; + } + + @Benchmark + 
public boolean chunkedBuf() + { + boolean success = true; + try + { + int chunkLen = len / subBufferCount; + + for (int i = 0; i < subBufferCount; i++) + { + ByteBuf buf = channel.alloc().directBuffer(chunkLen); + buf.writerIndex(chunkLen); + channel.write(buf); + } + channel.flush(); + } + catch (Exception e) + { + success = false; + } + finally + { + channel.releaseOutbound(); + } + + return success; + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZeroCopyStreamingBenchmark.java b/test/microbench/org/apache/cassandra/test/microbench/ZeroCopyStreamingBenchmark.java index 3192bccddefd..72571e8dc98c 100644 --- a/test/microbench/org/apache/cassandra/test/microbench/ZeroCopyStreamingBenchmark.java +++ b/test/microbench/org/apache/cassandra/test/microbench/ZeroCopyStreamingBenchmark.java @@ -51,8 +51,8 @@ import org.apache.cassandra.io.sstable.SSTableMultiWriter; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.schema.CachingParams; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.streaming.DefaultConnectionFactory; @@ -122,7 +122,7 @@ public void setupBenchmark() throws IOException blockStreamWriter = new CassandraEntireSSTableStreamWriter(sstable, session, CassandraOutgoingFile.getComponentManifest(sstable)); CapturingNettyChannel blockStreamCaptureChannel = new CapturingNettyChannel(STREAM_SIZE); - ByteBufDataOutputStreamPlus out = ByteBufDataOutputStreamPlus.create(session, blockStreamCaptureChannel, 1024 * 1024); + AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(blockStreamCaptureChannel); blockStreamWriter.write(out); serializedBlockStream = 
blockStreamCaptureChannel.getSerializedStream(); out.close(); @@ -152,7 +152,7 @@ public void setupBenchmark() throws IOException partialStreamWriter = new CassandraStreamWriter(sstable, sstable.getPositionsForRanges(requestedRanges), session); CapturingNettyChannel partialStreamChannel = new CapturingNettyChannel(STREAM_SIZE); - partialStreamWriter.write(ByteBufDataOutputStreamPlus.create(session, partialStreamChannel, 1024 * 1024)); + partialStreamWriter.write(new AsyncStreamingOutputPlus(partialStreamChannel)); serializedPartialStream = partialStreamChannel.getSerializedStream(); CassandraStreamHeader partialSSTableStreamHeader = @@ -200,7 +200,7 @@ private void generateData() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); } @@ -230,7 +230,7 @@ private StreamSession setupStreamingSessionForTest() public void blockStreamWriter(BenchmarkState state) throws Exception { EmbeddedChannel channel = createMockNettyChannel(); - ByteBufDataOutputStreamPlus out = ByteBufDataOutputStreamPlus.create(state.session, channel, 1024 * 1024); + AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(channel); state.blockStreamWriter.write(out); out.close(); channel.finishAndReleaseAll(); @@ -241,7 +241,7 @@ public void blockStreamWriter(BenchmarkState state) throws Exception public void blockStreamReader(BenchmarkState state) throws Exception { EmbeddedChannel channel = createMockNettyChannel(); - RebufferingByteBufDataInputPlus in = new RebufferingByteBufDataInputPlus(STREAM_SIZE, STREAM_SIZE, channel.config()); + AsyncStreamingInputPlus in = new AsyncStreamingInputPlus(channel); in.append(state.serializedBlockStream.retainedDuplicate()); SSTableMultiWriter sstableWriter = state.blockStreamReader.read(in); Collection newSstables = sstableWriter.finished(); @@ -254,7 +254,7 @@ public void blockStreamReader(BenchmarkState state) throws Exception public void 
partialStreamWriter(BenchmarkState state) throws Exception { EmbeddedChannel channel = createMockNettyChannel(); - ByteBufDataOutputStreamPlus out = ByteBufDataOutputStreamPlus.create(state.session, channel, 1024 * 1024); + AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(channel); state.partialStreamWriter.write(out); out.close(); channel.finishAndReleaseAll(); @@ -265,7 +265,7 @@ public void partialStreamWriter(BenchmarkState state) throws Exception public void partialStreamReader(BenchmarkState state) throws Exception { EmbeddedChannel channel = createMockNettyChannel(); - RebufferingByteBufDataInputPlus in = new RebufferingByteBufDataInputPlus(STREAM_SIZE, STREAM_SIZE, channel.config()); + AsyncStreamingInputPlus in = new AsyncStreamingInputPlus(channel); in.append(state.serializedPartialStream.retainedDuplicate()); SSTableMultiWriter sstableWriter = state.partialStreamReader.read(in); Collection newSstables = sstableWriter.finished(); diff --git a/test/unit/org/apache/cassandra/SchemaLoader.java b/test/unit/org/apache/cassandra/SchemaLoader.java index 41f80954f872..2bc9bda60d80 100644 --- a/test/unit/org/apache/cassandra/SchemaLoader.java +++ b/test/unit/org/apache/cassandra/SchemaLoader.java @@ -473,6 +473,35 @@ public static TableMetadata.Builder compositeIndexCFMD(String ksName, String cfN return builder.indexes(indexes.build()); } + public static TableMetadata.Builder compositeMultipleIndexCFMD(String ksName, String cfName) throws ConfigurationException + { + TableMetadata.Builder builder = TableMetadata.builder(ksName, cfName) + .addPartitionKeyColumn("key", AsciiType.instance) + .addClusteringColumn("c1", AsciiType.instance) + .addRegularColumn("birthdate", LongType.instance) + .addRegularColumn("notbirthdate", LongType.instance) + .compression(getCompressionParameters()); + + + Indexes.Builder indexes = Indexes.builder(); + + indexes.add(IndexMetadata.fromIndexTargets(Collections.singletonList( + new IndexTarget(new 
ColumnIdentifier("birthdate", true), + IndexTarget.Type.VALUES)), + "birthdate_key_index", + IndexMetadata.Kind.COMPOSITES, + Collections.EMPTY_MAP)); + indexes.add(IndexMetadata.fromIndexTargets(Collections.singletonList( + new IndexTarget(new ColumnIdentifier("notbirthdate", true), + IndexTarget.Type.VALUES)), + "notbirthdate_key_index", + IndexMetadata.Kind.COMPOSITES, + Collections.EMPTY_MAP)); + + + return builder.indexes(indexes.build()); + } + public static TableMetadata.Builder keysIndexCFMD(String ksName, String cfName, boolean withIndex) { TableMetadata.Builder builder = diff --git a/test/unit/org/apache/cassandra/Util.java b/test/unit/org/apache/cassandra/Util.java index ba5d4d369764..0e4611a48aed 100644 --- a/test/unit/org/apache/cassandra/Util.java +++ b/test/unit/org/apache/cassandra/Util.java @@ -38,11 +38,11 @@ import com.google.common.collect.Iterators; import org.apache.commons.lang3.StringUtils; -import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.cassandra.db.compaction.ActiveCompactionsTracker; +import org.apache.cassandra.db.compaction.CompactionTasks; import org.apache.cassandra.db.lifecycle.LifecycleTransaction; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.ReplicaCollection; @@ -188,7 +188,7 @@ public static ColumnFamilyStore writeColumnFamily(List mutations) rm.applyUnsafe(); ColumnFamilyStore store = Keyspace.open(keyspaceName).getColumnFamilyStore(tableId); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); return store; } @@ -247,9 +247,11 @@ public static Future compactAll(ColumnFamilyStore cfs, int gcBefore) public static void compact(ColumnFamilyStore cfs, Collection sstables) { int gcBefore = cfs.gcBefore(FBUtilities.nowInSeconds()); - List tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstables, gcBefore); - for (AbstractCompactionTask task : tasks) - task.execute(ActiveCompactionsTracker.NOOP); + try 
(CompactionTasks tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstables, gcBefore)) + { + for (AbstractCompactionTask task : tasks) + task.execute(ActiveCompactionsTracker.NOOP); + } } public static void expectEOF(Callable callable) @@ -707,6 +709,11 @@ public static Closeable markDirectoriesUnwriteable(ColumnFamilyStore cfs) } public static PagingState makeSomePagingState(ProtocolVersion protocolVersion) + { + return makeSomePagingState(protocolVersion, Integer.MAX_VALUE); + } + + public static PagingState makeSomePagingState(ProtocolVersion protocolVersion, int remainingInPartition) { TableMetadata metadata = TableMetadata.builder("ks", "tbl") @@ -722,7 +729,7 @@ public static PagingState makeSomePagingState(ProtocolVersion protocolVersion) Clustering c = Clustering.make(ByteBufferUtil.bytes("c1"), ByteBufferUtil.bytes(42)); Row row = BTreeRow.singleCellRow(c, BufferCell.live(def, 0, ByteBufferUtil.EMPTY_BYTE_BUFFER)); PagingState.RowMark mark = PagingState.RowMark.create(metadata, row, protocolVersion); - return new PagingState(pk, mark, 10, 0); + return new PagingState(pk, mark, 10, remainingInPartition); } public static void assertRCEquals(ReplicaCollection a, ReplicaCollection b) diff --git a/test/unit/org/apache/cassandra/audit/AuditLoggerTest.java b/test/unit/org/apache/cassandra/audit/AuditLoggerTest.java index b0299dc58be2..a44554729e65 100644 --- a/test/unit/org/apache/cassandra/audit/AuditLoggerTest.java +++ b/test/unit/org/apache/cassandra/audit/AuditLoggerTest.java @@ -17,6 +17,7 @@ */ package org.apache.cassandra.audit; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -63,6 +64,12 @@ public void beforeTestMethod() enableAuditLogOptions(options); } + @After + public void afterTestMethod() + { + disableAuditLogOptions(); + } + private void enableAuditLogOptions(AuditLogOptions options) { String loggerName = "InMemoryAuditLogger"; @@ -89,7 +96,7 @@ public void 
testAuditLogFilters() throws Throwable execute("INSERT INTO %s (id, v1, v2) VALUES (?, ?, ?)", 2, "trace", "test"); AuditLogOptions options = new AuditLogOptions(); - options.excluded_keyspaces = KEYSPACE; + options.excluded_keyspaces += ',' + KEYSPACE; enableAuditLogOptions(options); String cql = "SELECT id, v1, v2 FROM " + KEYSPACE + '.' + currentTable() + " WHERE id = ?"; @@ -106,7 +113,7 @@ public void testAuditLogFilters() throws Throwable options = new AuditLogOptions(); options.included_keyspaces = KEYSPACE; - options.excluded_keyspaces = KEYSPACE; + options.excluded_keyspaces += ',' + KEYSPACE; enableAuditLogOptions(options); cql = "SELECT id, v1, v2 FROM " + KEYSPACE + '.' + currentTable() + " WHERE id = ?"; @@ -129,7 +136,7 @@ public void testAuditLogFiltersTransitions() throws Throwable execute("INSERT INTO %s (id, v1, v2) VALUES (?, ?, ?)", 2, "trace", "test"); AuditLogOptions options = new AuditLogOptions(); - options.excluded_keyspaces = KEYSPACE; + options.excluded_keyspaces += ',' + KEYSPACE; enableAuditLogOptions(options); String cql = "SELECT id, v1, v2 FROM " + KEYSPACE + '.' + currentTable() + " WHERE id = ?"; @@ -144,7 +151,7 @@ public void testAuditLogFiltersTransitions() throws Throwable options = new AuditLogOptions(); options.included_keyspaces = KEYSPACE; - options.excluded_keyspaces = KEYSPACE; + options.excluded_keyspaces += ',' + KEYSPACE; enableAuditLogOptions(options); cql = "SELECT id, v1, v2 FROM " + KEYSPACE + '.' 
+ currentTable() + " WHERE id = ?"; @@ -162,11 +169,9 @@ public void testAuditLogFiltersTransitions() throws Throwable public void testAuditLogExceptions() { AuditLogOptions options = new AuditLogOptions(); - options.excluded_keyspaces = KEYSPACE; + options.excluded_keyspaces += ',' + KEYSPACE; enableAuditLogOptions(options); Assert.assertTrue(AuditLogManager.getInstance().isAuditingEnabled()); - - disableAuditLogOptions(); } @Test @@ -602,7 +607,7 @@ public void testIncludeSystemKeyspaces() throws Throwable { AuditLogOptions options = new AuditLogOptions(); options.included_categories = "QUERY,DML,PREPARE"; - options.excluded_keyspaces = "system_schema"; + options.excluded_keyspaces = "system_schema,system_virtual_schema"; enableAuditLogOptions(options); Session session = sessionNet(); @@ -620,7 +625,7 @@ public void testExcludeSystemKeyspaces() throws Throwable { AuditLogOptions options = new AuditLogOptions(); options.included_categories = "QUERY,DML,PREPARE"; - options.excluded_keyspaces = "system"; + options.excluded_keyspaces = "system,system_schema,system_virtual_schema"; enableAuditLogOptions(options); Session session = sessionNet(); diff --git a/test/unit/org/apache/cassandra/auth/AuthCacheTest.java b/test/unit/org/apache/cassandra/auth/AuthCacheTest.java index cc78ebc90a35..217821e8c7e4 100644 --- a/test/unit/org/apache/cassandra/auth/AuthCacheTest.java +++ b/test/unit/org/apache/cassandra/auth/AuthCacheTest.java @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -18,120 +17,215 @@ */ package org.apache.cassandra.auth; -import org.junit.Assert; -import org.junit.BeforeClass; +import java.util.function.BooleanSupplier; +import java.util.function.Function; +import java.util.function.IntConsumer; +import java.util.function.IntSupplier; + import org.junit.Test; -import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.ConsistencyLevel; +import org.apache.cassandra.exceptions.UnavailableException; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; public class AuthCacheTest { - private boolean loadFuncCalled = false; - private boolean isCacheEnabled = false; + private int loadCounter = 0; + private int validity = 2000; + private boolean isCacheEnabled = true; - @BeforeClass - public static void setup() + @Test + public void testCacheLoaderIsCalledOnFirst() { - DatabaseDescriptor.daemonInitialization(); + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(1, loadCounter); } @Test - public void testCaching() - { - AuthCache authCache = new AuthCache<>("TestCache", - DatabaseDescriptor::setCredentialsValidity, - DatabaseDescriptor::getCredentialsValidity, - DatabaseDescriptor::setCredentialsUpdateInterval, - DatabaseDescriptor::getCredentialsUpdateInterval, - DatabaseDescriptor::setCredentialsCacheMaxEntries, - DatabaseDescriptor::getCredentialsCacheMaxEntries, - this::load, - () -> true - ); - - // Test cacheloader is called if set - loadFuncCalled = false; - String result = authCache.get("test"); - assertTrue(loadFuncCalled); - Assert.assertEquals("load", result); - - // value should be fetched from cache - loadFuncCalled = false; - String result2 = 
authCache.get("test"); - assertFalse(loadFuncCalled); - Assert.assertEquals("load", result2); - - // value should be fetched from cache after complete invalidate + public void testCacheLoaderIsNotCalledOnSecond() + { + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.get("10"); + assertEquals(1, loadCounter); + + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(1, loadCounter); + } + + @Test + public void testCacheLoaderIsAlwaysCalledWhenDisabled() + { + isCacheEnabled = false; + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + + authCache.get("10"); + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(2, loadCounter); + } + + @Test + public void testCacheLoaderIsAlwaysCalledWhenValidityIsZero() + { + setValidity(0); + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + + authCache.get("10"); + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(2, loadCounter); + } + + @Test + public void testCacheLoaderIsCalledAfterFullInvalidate() + { + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.get("10"); + authCache.invalidate(); - loadFuncCalled = false; - String result3 = authCache.get("test"); - assertTrue(loadFuncCalled); - Assert.assertEquals("load", result3); - - // value should be fetched from cache after invalidating key - authCache.invalidate("test"); - loadFuncCalled = false; - String result4 = authCache.get("test"); - assertTrue(loadFuncCalled); - Assert.assertEquals("load", result4); - - // set cache to null and load function should be called - loadFuncCalled = false; + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(2, loadCounter); + } + + @Test + 
public void testCacheLoaderIsCalledAfterInvalidateKey() + { + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.get("10"); + + authCache.invalidate("10"); + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(2, loadCounter); + } + + @Test + public void testCacheLoaderIsCalledAfterReset() + { + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.get("10"); + authCache.cache = null; - String result5 = authCache.get("test"); - assertTrue(loadFuncCalled); - Assert.assertEquals("load", result5); + int result = authCache.get("10"); + + assertEquals(10, result); + assertEquals(2, loadCounter); } @Test - public void testInitCache() - { - // Test that a validity of <= 0 will turn off caching - DatabaseDescriptor.setCredentialsValidity(0); - AuthCache authCache = new AuthCache<>("TestCache2", - DatabaseDescriptor::setCredentialsValidity, - DatabaseDescriptor::getCredentialsValidity, - DatabaseDescriptor::setCredentialsUpdateInterval, - DatabaseDescriptor::getCredentialsUpdateInterval, - DatabaseDescriptor::setCredentialsCacheMaxEntries, - DatabaseDescriptor::getCredentialsCacheMaxEntries, - this::load, - () -> true); + public void testThatZeroValidityTurnOffCaching() + { + setValidity(0); + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.get("10"); + int result = authCache.get("10"); + assertNull(authCache.cache); + assertEquals(10, result); + assertEquals(2, loadCounter); + } + + @Test + public void testThatRaisingValidityTurnOnCaching() + { + setValidity(0); + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + authCache.setValidity(2000); authCache.cache = authCache.initCache(null); + assertNotNull(authCache.cache); + } + + @Test + public 
void testDisableCache() + { + isCacheEnabled = false; + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); - // Test enableCache works as intended - authCache = new AuthCache<>("TestCache3", - DatabaseDescriptor::setCredentialsValidity, - DatabaseDescriptor::getCredentialsValidity, - DatabaseDescriptor::setCredentialsUpdateInterval, - DatabaseDescriptor::getCredentialsUpdateInterval, - DatabaseDescriptor::setCredentialsCacheMaxEntries, - DatabaseDescriptor::getCredentialsCacheMaxEntries, - this::load, - () -> isCacheEnabled); assertNull(authCache.cache); + } + + @Test + public void testDynamicallyEnableCache() + { + isCacheEnabled = false; + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); + isCacheEnabled = true; authCache.cache = authCache.initCache(null); + assertNotNull(authCache.cache); + } + + @Test + public void testDefaultPolicies() + { + TestCache authCache = new TestCache<>(this::countingLoader, this::setValidity, () -> validity, () -> isCacheEnabled); - // Ensure at a minimum these policies have been initialised by default assertTrue(authCache.cache.policy().expireAfterWrite().isPresent()); assertTrue(authCache.cache.policy().refreshAfterWrite().isPresent()); assertTrue(authCache.cache.policy().eviction().isPresent()); } - private String load(String test) + @Test(expected = UnavailableException.class) + public void testCassandraExceptionPassThroughWhenCacheEnabled() + { + TestCache cache = new TestCache<>(s -> { throw UnavailableException.create(ConsistencyLevel.QUORUM, 3, 1); }, this::setValidity, () -> validity, () -> isCacheEnabled); + + cache.get("expect-exception"); + } + + @Test(expected = UnavailableException.class) + public void testCassandraExceptionPassThroughWhenCacheDisable() { - loadFuncCalled = true; - return "load"; + isCacheEnabled = false; + TestCache cache = new TestCache<>(s -> { throw 
UnavailableException.create(ConsistencyLevel.QUORUM, 3, 1); }, this::setValidity, () -> validity, () -> isCacheEnabled); + + cache.get("expect-exception"); } + private void setValidity(int validity) + { + this.validity = validity; + } + + private Integer countingLoader(String s) + { + loadCounter++; + return Integer.parseInt(s); + } + + private static class TestCache extends AuthCache + { + private static int nameCounter = 0; // Allow us to create many instances of cache with same name prefix + + TestCache(Function loadFunction, IntConsumer setValidityDelegate, IntSupplier getValidityDelegate, BooleanSupplier cacheEnabledDelegate) + { + super("TestCache" + nameCounter++, + setValidityDelegate, + getValidityDelegate, + (updateInterval) -> {}, + () -> 1000, + (maxEntries) -> {}, + () -> 10, + loadFunction, + cacheEnabledDelegate); + } + } } diff --git a/test/unit/org/apache/cassandra/auth/CassandraNetworkAuthorizerTest.java b/test/unit/org/apache/cassandra/auth/CassandraNetworkAuthorizerTest.java index c24a769aed45..2e57173bb5cd 100644 --- a/test/unit/org/apache/cassandra/auth/CassandraNetworkAuthorizerTest.java +++ b/test/unit/org/apache/cassandra/auth/CassandraNetworkAuthorizerTest.java @@ -51,6 +51,7 @@ import static org.apache.cassandra.auth.AuthKeyspace.NETWORK_PERMISSIONS; import static org.apache.cassandra.auth.RoleTestUtils.LocalCassandraRoleManager; import static org.apache.cassandra.schema.SchemaConstants.AUTH_KEYSPACE_NAME; +import static org.apache.cassandra.auth.RoleTestUtils.getReadCount; public class CassandraNetworkAuthorizerTest { @@ -105,6 +106,8 @@ public static void defineSchema() throws ConfigurationException new LocalCassandraAuthorizer(), new LocalCassandraNetworkAuthorizer()); setupSuperUser(); + // not strictly necessary to init the cache here, but better to be explicit + Roles.initRolesCache(DatabaseDescriptor.getRoleManager(), () -> true); } @Before @@ -227,6 +230,8 @@ public void superUser() Assert.assertEquals(DCPermissions.subset("dc1"), 
dcPerms(username)); assertDcPermRow(username, "dc1"); + // clear the roles cache to lose the (non-)superuser status for the user + Roles.clearCache(); auth("ALTER ROLE %s WITH superuser = true", username); Assert.assertEquals(DCPermissions.all(), dcPerms(username)); } @@ -238,4 +243,16 @@ public void cantLogin() auth("CREATE ROLE %s", username); Assert.assertEquals(DCPermissions.none(), dcPerms(username)); } + + @Test + public void getLoginPrivilegeFromRolesCache() throws Exception + { + String username = createName(); + auth("CREATE ROLE %s", username); + long readCount = getReadCount(); + dcPerms(username); + Assert.assertEquals(++readCount, getReadCount()); + dcPerms(username); + Assert.assertEquals(readCount, getReadCount()); + } } diff --git a/test/unit/org/apache/cassandra/auth/PasswordAuthenticatorTest.java b/test/unit/org/apache/cassandra/auth/PasswordAuthenticatorTest.java index 37763d74de98..fd79b6aeade9 100644 --- a/test/unit/org/apache/cassandra/auth/PasswordAuthenticatorTest.java +++ b/test/unit/org/apache/cassandra/auth/PasswordAuthenticatorTest.java @@ -18,8 +18,22 @@ package org.apache.cassandra.auth; +import java.nio.charset.StandardCharsets; + +import com.google.common.collect.Iterables; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; +import com.datastax.driver.core.Authenticator; +import com.datastax.driver.core.PlainTextAuthProvider; +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.exceptions.AuthenticationException; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.SchemaConstants; +import org.apache.cassandra.schema.TableMetadata; + import static org.apache.cassandra.auth.CassandraRoleManager.*; import static org.apache.cassandra.auth.PasswordAuthenticator.*; import static org.junit.Assert.assertFalse; @@ -27,8 +41,11 @@ import static org.mindrot.jbcrypt.BCrypt.hashpw; import static 
org.mindrot.jbcrypt.BCrypt.gensalt; -public class PasswordAuthenticatorTest +public class PasswordAuthenticatorTest extends CQLTester { + + private static PasswordAuthenticator authenticator = new PasswordAuthenticator(); + @Test public void testCheckpw() throws Exception { @@ -61,4 +78,67 @@ public void testCheckpw() throws Exception assertFalse(checkpw(DEFAULT_SUPERUSER_PASSWORD, "$2$6$abcdefghijklmnopqrstuvABCDEFGHIJKLMNOPQRSTUVWXYZ01234")); assertFalse(checkpw(DEFAULT_SUPERUSER_PASSWORD, "$2a$6$abcdefghijklmnopqrstuvABCDEFGHIJKLMNOPQRSTUVWXYZ01234")); } + + @Test(expected = AuthenticationException.class) + public void testEmptyUsername() + { + testDecodeIllegalUserAndPwd("", "pwd"); + } + + @Test(expected = AuthenticationException.class) + public void testEmptyPassword() + { + testDecodeIllegalUserAndPwd("user", ""); + } + + @Test(expected = AuthenticationException.class) + public void testNULUsername0() + { + byte[] user = {'u', 's', PasswordAuthenticator.NUL, 'e', 'r'}; + testDecodeIllegalUserAndPwd(new String(user, StandardCharsets.UTF_8), "pwd"); + } + + @Test(expected = AuthenticationException.class) + public void testNULUsername1() + { + testDecodeIllegalUserAndPwd(new String(new byte[4]), "pwd"); + } + + @Test(expected = AuthenticationException.class) + public void testNULPassword0() + { + byte[] pwd = {'p', 'w', PasswordAuthenticator.NUL, 'd'}; + testDecodeIllegalUserAndPwd("user", new String(pwd, StandardCharsets.UTF_8)); + } + + @Test(expected = AuthenticationException.class) + public void testNULPassword1() + { + testDecodeIllegalUserAndPwd("user", new String(new byte[4])); + } + + private void testDecodeIllegalUserAndPwd(String username, String password) + { + SaslNegotiator negotiator = authenticator.newSaslNegotiator(null); + Authenticator clientAuthenticator = (new PlainTextAuthProvider(username, password)) + .newAuthenticator(null, null); + + negotiator.evaluateResponse(clientAuthenticator.initialResponse()); + negotiator.getAuthenticatedUser(); 
+ } + + @BeforeClass + public static void setUp() + { + SchemaLoader.createKeyspace(SchemaConstants.AUTH_KEYSPACE_NAME, + KeyspaceParams.simple(1), + Iterables.toArray(AuthKeyspace.metadata().tables, TableMetadata.class)); + authenticator.setup(); + } + + @AfterClass + public static void tearDown() + { + schemaChange("DROP KEYSPACE " + SchemaConstants.AUTH_KEYSPACE_NAME); + } } \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/batchlog/BatchlogManagerTest.java b/test/unit/org/apache/cassandra/batchlog/BatchlogManagerTest.java index 33fb209389d9..5ab7597e4a99 100644 --- a/test/unit/org/apache/cassandra/batchlog/BatchlogManagerTest.java +++ b/test/unit/org/apache/cassandra/batchlog/BatchlogManagerTest.java @@ -54,6 +54,7 @@ import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.UUIDGen; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.cassandra.cql3.QueryProcessor.executeInternal; import static org.junit.Assert.*; @@ -155,7 +156,7 @@ public void testReplay() throws Exception } // Flush the batchlog to disk (see CASSANDRA-6822). - Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlush(); + Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlushToSSTable(); assertEquals(100, BatchlogManager.instance.countAllBatches() - initialAllBatches); assertEquals(0, BatchlogManager.instance.getTotalBatchesReplayed() - initialReplayedBatches); @@ -239,7 +240,7 @@ public void testTruncatedReplay() throws InterruptedException, ExecutionExceptio } // Flush the batchlog to disk (see CASSANDRA-6822). 
- Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlush(); + Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlushToSSTable(); // Force batchlog replay and wait for it to complete. BatchlogManager.instance.startBatchlogReplay().get(); @@ -277,7 +278,7 @@ public void testAddBatch() throws IOException long initialAllBatches = BatchlogManager.instance.countAllBatches(); TableMetadata cfm = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD5).metadata(); - long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout() * 2) * 1000; + long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS) * 2) * 1000; UUID uuid = UUIDGen.getTimeUUID(); // Add a batch with 10 mutations @@ -309,7 +310,7 @@ public void testRemoveBatch() long initialAllBatches = BatchlogManager.instance.countAllBatches(); TableMetadata cfm = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD5).metadata(); - long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout() * 2) * 1000; + long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS) * 2) * 1000; UUID uuid = UUIDGen.getTimeUUID(); // Add a batch with 10 mutations @@ -351,7 +352,7 @@ public void testReplayWithNoPeers() throws Exception TableMetadata cfm = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD1).metadata(); - long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout() * 2) * 1000; + long timestamp = (System.currentTimeMillis() - DatabaseDescriptor.getWriteRpcTimeout(MILLISECONDS) * 2) * 1000; UUID uuid = UUIDGen.getTimeUUID(); // Add a batch with 10 mutations @@ -367,7 +368,7 @@ public void testReplayWithNoPeers() throws Exception assertEquals(1, BatchlogManager.instance.countAllBatches() - initialAllBatches); // Flush the batchlog to disk 
(see CASSANDRA-6822). - Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlush(); + Keyspace.open(SchemaConstants.SYSTEM_KEYSPACE_NAME).getColumnFamilyStore(SystemKeyspace.BATCHES).forceBlockingFlushToSSTable(); assertEquals(1, BatchlogManager.instance.countAllBatches() - initialAllBatches); assertEquals(0, BatchlogManager.instance.getTotalBatchesReplayed() - initialReplayedBatches); diff --git a/test/unit/org/apache/cassandra/cache/AutoSavingCacheTest.java b/test/unit/org/apache/cassandra/cache/AutoSavingCacheTest.java index bb5129af9a46..7dd26b1478c0 100644 --- a/test/unit/org/apache/cassandra/cache/AutoSavingCacheTest.java +++ b/test/unit/org/apache/cassandra/cache/AutoSavingCacheTest.java @@ -74,7 +74,7 @@ private static void doTestSerializeAndLoadKeyCache() throws Exception RowUpdateBuilder rowBuilder = new RowUpdateBuilder(cfs.metadata(), System.currentTimeMillis(), "key1"); rowBuilder.add(colDef, "val1"); rowBuilder.build().apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } Assert.assertEquals(2, cfs.getLiveSSTables().size()); diff --git a/test/unit/org/apache/cassandra/concurrent/SEPExecutorTest.java b/test/unit/org/apache/cassandra/concurrent/SEPExecutorTest.java index 011a8bac1746..b6cf9bea7040 100644 --- a/test/unit/org/apache/cassandra/concurrent/SEPExecutorTest.java +++ b/test/unit/org/apache/cassandra/concurrent/SEPExecutorTest.java @@ -22,12 +22,15 @@ import java.io.PrintStream; import java.util.Arrays; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import org.junit.Assert; import org.junit.Test; import org.apache.cassandra.utils.FBUtilities; +import static java.util.concurrent.TimeUnit.MINUTES; + public class SEPExecutorTest { @Test @@ -56,7 +59,7 @@ public void write(int b) { } } // shutdown does not guarantee that threads are actually dead once it exits, only that they will stop promptly afterwards - sharedPool.shutdown(); + 
sharedPool.shutdownAndWait(1L, TimeUnit.MINUTES); for (Thread thread : Thread.getAllStackTraces().keySet()) { if (thread.getName().contains(MAGIC)) diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 1eb56bcbb65f..b09fa0cfa74f 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -52,6 +52,7 @@ import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.TokenMetadata; +import org.apache.cassandra.metrics.ClientMetrics; import org.apache.cassandra.schema.*; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.cql3.functions.FunctionName; @@ -404,6 +405,7 @@ protected static void requireNetwork() throws ConfigurationException SchemaLoader.startGossiper(); server = new Server.Builder().withHost(nativeAddr).withPort(nativePort).build(); + ClientMetrics.instance.init(Collections.singleton(server)); server.start(); for (ProtocolVersion version : PROTOCOL_VERSIONS) @@ -472,7 +474,7 @@ public void flush(String keyspace) { ColumnFamilyStore store = getCurrentColumnFamilyStore(keyspace); if (store != null) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); } public void disableCompaction(String keyspace) @@ -891,9 +893,14 @@ protected Session sessionNet(ProtocolVersion protocolVersion) return sessions.get(protocolVersion); } + protected SimpleClient newSimpleClient(ProtocolVersion version, boolean compression, boolean checksums, boolean isOverloadedException) throws IOException + { + return new SimpleClient(nativeAddr.getHostAddress(), nativePort, version, version.isBeta(), new EncryptionOptions()).connect(compression, checksums, isOverloadedException); + } + protected SimpleClient newSimpleClient(ProtocolVersion version, boolean compression, boolean checksums) throws IOException { - return new 
SimpleClient(nativeAddr.getHostAddress(), nativePort, version, version.isBeta(), new EncryptionOptions()).connect(compression, checksums); + return newSimpleClient(version, compression, checksums, false); } protected String formatQuery(String query) @@ -1690,6 +1697,49 @@ protected com.datastax.driver.core.TupleType tupleTypeOf(ProtocolVersion protoco return clusters.get(protocolVersion).getMetadata().newTupleType(types); } + /** + * Creates a default reference table with some pre-populated data, generating enough data to create a few commit log + * files. + */ + protected void populateReferenceData(boolean withCDC) throws Throwable + { + + String createString = "CREATE TABLE %s (a int, b int, c double, d decimal, e smallint, f tinyint, g blob, primary key (a, b))"; + if (withCDC) + createString += " WITH cdc=true"; + + // For each test, we start with the assumption of a populated set of a few files we can pull from. + createTable(createString); + + byte[] buffer = new byte[1024 * 256]; + CommitLog.instance.sync(true); + + // Populate some CommitLog segments on disk + writeReferenceLines(80, buffer); + CommitLog.instance.sync(true); + } + + protected void writeReferenceLines(int num, byte[] buffer) throws Throwable + { + for (int i = 0; i < num; i++) + writeReferenceDataLine(buffer); + } + + /** Broken out for access from tests that need to write incrementally more ref. data */ + protected void writeReferenceDataLine(byte[] buffer) throws Throwable + { + Random random = new Random(); + random.nextBytes(buffer); + execute("INSERT INTO %s (a, b, c, d, e, f, g) VALUES (?, ?, ?, ?, ?, ?, ?)", + random.nextInt(), + random.nextInt(), + random.nextDouble(), + random.nextLong(), + (short)random.nextInt(), + (byte)random.nextInt(), + ByteBuffer.wrap(buffer)); + } + // Attempt to find an AbstracType from a value (for serialization/printing sake). 
// Will work as long as we use types we know of, which is good enough for testing private static AbstractType typeFor(Object value) diff --git a/test/unit/org/apache/cassandra/cql3/GcCompactionTest.java b/test/unit/org/apache/cassandra/cql3/GcCompactionTest.java index 2fc07eb89db5..375ffcaa7a3e 100644 --- a/test/unit/org/apache/cassandra/cql3/GcCompactionTest.java +++ b/test/unit/org/apache/cassandra/cql3/GcCompactionTest.java @@ -346,11 +346,11 @@ public void testLocalDeletionTime() throws Throwable createTable("create table %s (k int, c1 int, primary key (k, c1)) with compaction = {'class': 'SizeTieredCompactionStrategy', 'provide_overlapping_tombstones':'row'}"); execute("delete from %s where k = 1"); Set readers = new HashSet<>(getCurrentColumnFamilyStore().getLiveSSTables()); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); SSTableReader oldSSTable = getNewTable(readers); Thread.sleep(2000); execute("delete from %s where k = 1"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); SSTableReader newTable = getNewTable(readers); CompactionManager.instance.forceUserDefinedCompaction(oldSSTable.getFilename()); diff --git a/test/unit/org/apache/cassandra/cql3/KeyCacheCqlTest.java b/test/unit/org/apache/cassandra/cql3/KeyCacheCqlTest.java index b76cc784396c..1268bb002eb8 100644 --- a/test/unit/org/apache/cassandra/cql3/KeyCacheCqlTest.java +++ b/test/unit/org/apache/cassandra/cql3/KeyCacheCqlTest.java @@ -543,7 +543,7 @@ private void insertData(String table, String index, boolean withClustering) thro if (i % 10 == 9) { - Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).forceFlush().get(); + Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).forceFlushToSSTable().get(); if (index != null) triggerBlockingFlush(Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).indexManager.getIndexByName(index)); } @@ 
-553,7 +553,7 @@ private void insertData(String table, String index, boolean withClustering) thro private static void prepareTable(String table) throws IOException, InterruptedException, java.util.concurrent.ExecutionException { StorageService.instance.disableAutoCompaction(KEYSPACE_PER_TEST, table); - Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).forceFlush().get(); + Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).forceFlushToSSTable().get(); Keyspace.open(KEYSPACE_PER_TEST).getColumnFamilyStore(table).truncateBlocking(); } diff --git a/test/unit/org/apache/cassandra/cql3/OutOfSpaceTest.java b/test/unit/org/apache/cassandra/cql3/OutOfSpaceTest.java index b4fe0f5fd2f3..cd86ce313da6 100644 --- a/test/unit/org/apache/cassandra/cql3/OutOfSpaceTest.java +++ b/test/unit/org/apache/cassandra/cql3/OutOfSpaceTest.java @@ -115,7 +115,7 @@ public void flushAndExpectError() throws InterruptedException, ExecutionExceptio { try { - Keyspace.open(KEYSPACE).getColumnFamilyStore(currentTable()).forceFlush().get(); + Keyspace.open(KEYSPACE).getColumnFamilyStore(currentTable()).forceFlushToSSTable().get(); fail("FSWriteError expected."); } catch (ExecutionException e) @@ -126,7 +126,7 @@ public void flushAndExpectError() throws InterruptedException, ExecutionExceptio // Make sure commit log wasn't discarded. 
TableId tableId = currentTableMetadata().id; - for (CommitLogSegment segment : CommitLog.instance.segmentManager.getActiveSegments()) + for (CommitLogSegment segment : CommitLog.instance.segmentManager.getSegmentsForUnflushedTables()) if (segment.getDirtyTableIds().contains(tableId)) return; fail("Expected commit log to remain dirty for the affected table."); diff --git a/test/unit/org/apache/cassandra/cql3/ViewComplexTest.java b/test/unit/org/apache/cassandra/cql3/ViewComplexTest.java index d24ab526385a..44c20d010b63 100644 --- a/test/unit/org/apache/cassandra/cql3/ViewComplexTest.java +++ b/test/unit/org/apache/cassandra/cql3/ViewComplexTest.java @@ -844,8 +844,8 @@ private void testExpiredLivenessLimit(boolean flush) throws Throwable } if (flush) { - ks.getColumnFamilyStore("mv1").forceBlockingFlush(); - ks.getColumnFamilyStore("mv2").forceBlockingFlush(); + ks.getColumnFamilyStore("mv1").forceBlockingFlushToSSTable(); + ks.getColumnFamilyStore("mv2").forceBlockingFlushToSSTable(); } for (String view : Arrays.asList("mv1", "mv2")) diff --git a/test/unit/org/apache/cassandra/cql3/ViewFilteringTest.java b/test/unit/org/apache/cassandra/cql3/ViewFilteringTest.java index 8b4a556b722a..73b50744299b 100644 --- a/test/unit/org/apache/cassandra/cql3/ViewFilteringTest.java +++ b/test/unit/org/apache/cassandra/cql3/ViewFilteringTest.java @@ -2121,7 +2121,7 @@ public void testOldTimestampsWithRestrictions() throws Throwable for (int i = 0; i < 100; i++) updateView("INSERT into %s (k,c,val)VALUES(?,?,?)", 0, i % 2, "baz"); - Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlush(); + Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlushToSSTable(); Assert.assertEquals(2, execute("select * from %s").size()); Assert.assertEquals(2, execute("select * from mv_tstest").size()); diff --git a/test/unit/org/apache/cassandra/cql3/ViewTest.java b/test/unit/org/apache/cassandra/cql3/ViewTest.java index 02fa19effa09..8647f4670982 
100644 --- a/test/unit/org/apache/cassandra/cql3/ViewTest.java +++ b/test/unit/org/apache/cassandra/cql3/ViewTest.java @@ -128,7 +128,7 @@ public void testExistingRangeTombstone(boolean flush) throws Throwable updateView("DELETE FROM %s USING TIMESTAMP 10 WHERE k1 = 1 and c1=1"); if (flush) - Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlush(); + Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlushToSSTable(); String table = KEYSPACE + "." + currentTable(); updateView("BEGIN BATCH " + @@ -338,7 +338,7 @@ public void testOldTimestamps() throws Throwable for (int i = 0; i < 100; i++) updateView("INSERT into %s (k,c,val)VALUES(?,?,?)", 0, i % 2, "baz"); - Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlush(); + Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlushToSSTable(); Assert.assertEquals(2, execute("select * from %s").size()); Assert.assertEquals(2, execute("select * from mv_tstest").size()); @@ -926,7 +926,7 @@ public void testIgnoreUpdate() throws Throwable assertRows(execute("SELECT a, b, c from mv WHERE b = ?", 1), row(0, 1, null)); ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore("mv"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Assert.assertEquals(1, cfs.getLiveSSTables().size()); } @@ -1350,22 +1350,22 @@ private void testViewBuilderResume(int concurrentViewBuilders) throws Throwable for (int i = 0; i < 1024; i++) execute("INSERT into %s (k,c,val)VALUES(?,?,?)", i, i, ""+i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int i = 0; i < 1024; i++) execute("INSERT into %s (k,c,val)VALUES(?,?,?)", i, i, ""+i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int i = 0; i < 1024; i++) execute("INSERT into %s (k,c,val)VALUES(?,?,?)", i, i, ""+i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int i = 0; i < 1024; i++) 
execute("INSERT into %s (k,c,val)VALUES(?,?,?)", i, i, ""+i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); String viewName1 = "mv_test_" + concurrentViewBuilders; createView(viewName1, "CREATE MATERIALIZED VIEW %s AS SELECT * FROM %%s WHERE val IS NOT NULL AND k IS NOT NULL AND c IS NOT NULL PRIMARY KEY (val,k,c)"); diff --git a/test/unit/org/apache/cassandra/cql3/functions/OperationFctsTest.java b/test/unit/org/apache/cassandra/cql3/functions/OperationFctsTest.java index d27b746c1702..c8ee9352e944 100644 --- a/test/unit/org/apache/cassandra/cql3/functions/OperationFctsTest.java +++ b/test/unit/org/apache/cassandra/cql3/functions/OperationFctsTest.java @@ -177,35 +177,35 @@ public void testSingleOperations() throws Throwable row((short) 0, (short) 1, 1, 2L, 2.75F, 3.25, BigInteger.valueOf(3), new BigDecimal("4.25"))); assertRows(execute("SELECT a / c, b / c, c / c, d / c, e / c, f / c, g / c, h / c FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), - row(0, 0, 1, 1L, 1.8333334F, 2.1666666666666665, BigInteger.valueOf(2), new BigDecimal("2.833333333333333333333333333333333"))); + row(0, 0, 1, 1L, 1.8333334F, 2.1666666666666665, BigInteger.valueOf(2), new BigDecimal("2.83333333333333333333333333333333"))); assertRows(execute("SELECT a / d, b / d, c / d, d / d, e / d, f / d, g / d, h / d FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), row(0L, 0L, 0L, 1L, 1.375, 1.625, BigInteger.valueOf(1), new BigDecimal("2.125"))); assertRows(execute("SELECT a / e, b / e, c / e, d / e, e / e, f / e, g / e, h / e FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), - row(0.18181819F, 0.36363637F, 0.54545456F, 0.7272727272727273, 1.0F, 1.1818181818181819, new BigDecimal("1.272727272727272727272727272727273"), new BigDecimal("1.545454545454545454545454545454545"))); + row(0.18181819F, 0.36363637F, 0.54545456F, 0.7272727272727273, 1.0F, 1.1818181818181819, new BigDecimal("1.27272727272727272727272727272727"), new BigDecimal("1.54545454545454545454545454545455"))); 
assertRows(execute("SELECT a / f, b / f, c / f, d / f, e / f, f / f, g / f, h / f FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), - row(0.15384615384615385, 0.3076923076923077, 0.46153846153846156, 0.6153846153846154, 0.8461538461538461, 1.0, new BigDecimal("1.076923076923076923076923076923077"), new BigDecimal("1.307692307692307692307692307692308"))); + row(0.15384615384615385, 0.3076923076923077, 0.46153846153846156, 0.6153846153846154, 0.8461538461538461, 1.0, new BigDecimal("1.07692307692307692307692307692308"), new BigDecimal("1.30769230769230769230769230769231"))); assertRows(execute("SELECT a / g, b / g, c / g, d / g, e / g, f / g, g / g, h / g FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), row(BigInteger.valueOf(0), BigInteger.valueOf(0), BigInteger.valueOf(0), BigInteger.valueOf(0), - new BigDecimal("0.7857142857142857142857142857142857"), - new BigDecimal("0.9285714285714285714285714285714286"), + new BigDecimal("0.78571428571428571428571428571429"), + new BigDecimal("0.92857142857142857142857142857143"), BigInteger.valueOf(1), - new BigDecimal("1.214285714285714285714285714285714"))); + new BigDecimal("1.21428571428571428571428571428571"))); assertRows(execute("SELECT a / h, b / h, c / h, d / h, e / h, f / h, g / h, h / h FROM %s WHERE a = 1 AND b = 2 AND c = 3 / 1"), - row(new BigDecimal("0.1176470588235294117647058823529412"), - new BigDecimal("0.2352941176470588235294117647058824"), - new BigDecimal("0.3529411764705882352941176470588235"), - new BigDecimal("0.4705882352941176470588235294117647"), - new BigDecimal("0.6470588235294117647058823529411765"), - new BigDecimal("0.7647058823529411764705882352941176"), - new BigDecimal("0.8235294117647058823529411764705882"), + row(new BigDecimal("0.11764705882352941176470588235294"), + new BigDecimal("0.23529411764705882352941176470588"), + new BigDecimal("0.35294117647058823529411764705882"), + new BigDecimal("0.47058823529411764705882352941176"), + new BigDecimal("0.64705882352941176470588235294118"), + new 
BigDecimal("0.76470588235294117647058823529412"), + new BigDecimal("0.82352941176470588235294117647059"), new BigDecimal("1"))); // Test modulo operations @@ -265,6 +265,16 @@ public void testSingleOperations() throws Throwable row(null, null, null, null, null, null, null, null)); } + @Test + public void testModuloWithDecimals() throws Throwable + { + createTable("CREATE TABLE %s (numerator decimal, dec_mod decimal, int_mod int, bigint_mod bigint, PRIMARY KEY((numerator, dec_mod)))"); + execute("INSERT INTO %s (numerator, dec_mod, int_mod, bigint_mod) VALUES (123456789112345678921234567893123456, 2, 2, 2)"); + + assertRows(execute("SELECT numerator %% dec_mod, numerator %% int_mod, numerator %% bigint_mod from %s"), + row(new BigDecimal("0"), new BigDecimal("0.0"), new BigDecimal("0.0"))); + } + @Test public void testSingleOperationsWithLiterals() throws Throwable { @@ -438,7 +448,7 @@ public void testSingleOperationsWithLiterals() throws Throwable row(0, 1, 1, 2L, 2.75F, 3.25, BigInteger.valueOf(3), new BigDecimal("4.25"))); assertRows(execute("SELECT a / 3, b / 3, c / 3, d / 3, e / 3, f / 3, g / 3, h / 3 FROM %s WHERE a = 1 AND b = 2"), - row(0, 0, 1, 1L, 1.8333334F, 2.1666666666666665, BigInteger.valueOf(2), new BigDecimal("2.833333333333333333333333333333333"))); + row(0, 0, 1, 1L, 1.8333334F, 2.1666666666666665, BigInteger.valueOf(2), new BigDecimal("2.83333333333333333333333333333333"))); assertRows(execute("SELECT a / " + bigInt + "," + " b / " + bigInt + "," @@ -456,10 +466,10 @@ public void testSingleOperationsWithLiterals() throws Throwable BigInteger.valueOf(7).divide(BigInteger.valueOf(bigInt)))); assertRows(execute("SELECT a / 5.5, b / 5.5, c / 5.5, d / 5.5, e / 5.5, f / 5.5, g / 5.5, h / 5.5 FROM %s WHERE a = 1 AND b = 2"), - row(0.18181818181818182, 0.36363636363636365, 0.5454545454545454, 0.7272727272727273, 1.0, 1.1818181818181819, new BigDecimal("1.272727272727272727272727272727273"), new BigDecimal("1.545454545454545454545454545454545"))); + 
row(0.18181818181818182, 0.36363636363636365, 0.5454545454545454, 0.7272727272727273, 1.0, 1.1818181818181819, new BigDecimal("1.27272727272727272727272727272727"), new BigDecimal("1.54545454545454545454545454545455"))); assertRows(execute("SELECT a / 6.5, b / 6.5, c / 6.5, d / 6.5, e / 6.5, f / 6.5, g / 6.5, h / 6.5 FROM %s WHERE a = 1 AND b = 2"), - row(0.15384615384615385, 0.3076923076923077, 0.46153846153846156, 0.6153846153846154, 0.8461538461538461, 1.0, new BigDecimal("1.076923076923076923076923076923077"), new BigDecimal("1.307692307692307692307692307692308"))); + row(0.15384615384615385, 0.3076923076923077, 0.46153846153846156, 0.6153846153846154, 0.8461538461538461, 1.0, new BigDecimal("1.07692307692307692307692307692308"), new BigDecimal("1.30769230769230769230769230769231"))); // Test modulo operations @@ -502,6 +512,18 @@ public void testSingleOperationsWithLiterals() throws Throwable row((byte) 1, (short) 2, 2, 1, 4, 2, 0, -1)); } + @Test + public void testDivisionWithDecimals() throws Throwable + { + createTable("CREATE TABLE %s (numerator decimal, denominator decimal, PRIMARY KEY((numerator, denominator)))"); + execute("INSERT INTO %s (numerator, denominator) VALUES (8.5, 200000000000000000000000000000000000)"); + execute("INSERT INTO %s (numerator, denominator) VALUES (10000, 3)"); + + assertRows(execute("SELECT numerator / denominator from %s"), + row(new BigDecimal("0.0000000000000000000000000000000000425")), + row(new BigDecimal("3333.33333333333333333333333333333333"))); + } + @Test public void testWithCounters() throws Throwable { @@ -663,7 +685,7 @@ public void testWithDivisionByZero() throws Throwable OperationExecutionException.class, "SELECT g / a FROM %s WHERE a = 0 AND b = 2"); - assertInvalidThrowMessage("the operation 'decimal / tinyint' failed: Division by zero", + assertInvalidThrowMessage("the operation 'decimal / tinyint' failed: BigInteger divide by zero", OperationExecutionException.class, "SELECT h / a FROM %s WHERE a = 0 AND b 
= 2"); } diff --git a/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/CrcCheckChanceTest.java b/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/CrcCheckChanceTest.java index 246f512f66b5..dc7e6c0a955e 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/CrcCheckChanceTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/CrcCheckChanceTest.java @@ -68,7 +68,7 @@ public void testChangingCrcCheckChance(boolean newFormat) throws Throwable ColumnFamilyStore cfs = Keyspace.open(CQLTester.KEYSPACE).getColumnFamilyStore(currentTable()); ColumnFamilyStore indexCfs = cfs.indexManager.getAllIndexColumnFamilyStores().iterator().next(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Assert.assertEquals(0.99, cfs.getCrcCheckChance(), 0.0); Assert.assertEquals(0.99, cfs.getLiveSSTables().iterator().next().getCrcCheckChance(), 0.0); @@ -96,19 +96,19 @@ public void testChangingCrcCheckChance(boolean newFormat) throws Throwable execute("INSERT INTO %s(p, c, v) values (?, ?, ?)", "p1", "k2", "v2"); execute("INSERT INTO %s(p, s) values (?, ?)", "p2", "sv2"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); execute("INSERT INTO %s(p, c, v, s) values (?, ?, ?, ?)", "p1", "k1", "v1", "sv1"); execute("INSERT INTO %s(p, c, v) values (?, ?, ?)", "p1", "k2", "v2"); execute("INSERT INTO %s(p, s) values (?, ?)", "p2", "sv2"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); execute("INSERT INTO %s(p, c, v, s) values (?, ?, ?, ?)", "p1", "k1", "v1", "sv1"); execute("INSERT INTO %s(p, c, v) values (?, ?, ?)", "p1", "k2", "v2"); execute("INSERT INTO %s(p, s) values (?, ?)", "p2", "sv2"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.forceMajorCompaction(); //Now let's change via JMX @@ -182,7 +182,7 @@ public void testDropDuringCompaction() throws Throwable execute("INSERT INTO %s(p, c, v) values (?, ?, ?)", "p1", "k2", "v2"); execute("INSERT INTO %s(p, s) 
values (?, ?)", "p2", "sv2"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } DatabaseDescriptor.setCompactionThroughputMbPerSec(1); diff --git a/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/SSTableMetadataTrackingTest.java b/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/SSTableMetadataTrackingTest.java index 288cbe1a042e..0b9e71b138fb 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/SSTableMetadataTrackingTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/miscellaneous/SSTableMetadataTrackingTest.java @@ -33,7 +33,7 @@ public void baseCheck() throws Throwable createTable("CREATE TABLE %s (a int, b int, c text, PRIMARY KEY (a, b))"); ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("INSERT INTO %s (a,b,c) VALUES (1,1,'1') using timestamp 9999"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); assertEquals(Integer.MAX_VALUE, metadata.maxLocalDeletionTime); @@ -50,7 +50,7 @@ public void testMinMaxtimestampRange() throws Throwable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("INSERT INTO %s (a,b,c) VALUES (1,1,'1') using timestamp 10000"); execute("DELETE FROM %s USING TIMESTAMP 9999 WHERE a = 1 and b = 1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); assertEquals(10000, metadata.maxTimestamp); @@ -69,7 +69,7 @@ public void testMinMaxtimestampRow() throws Throwable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("INSERT INTO %s (a,b,c) VALUES (1,1,'1') using timestamp 10000"); execute("DELETE FROM %s USING TIMESTAMP 9999 WHERE a = 1"); - 
cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); assertEquals(10000, metadata.maxTimestamp); @@ -88,7 +88,7 @@ public void testTrackMetadata_rangeTombstone() throws Throwable createTable("CREATE TABLE %s (a int, b int, c text, PRIMARY KEY (a, b)) WITH gc_grace_seconds = 10000"); ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("DELETE FROM %s USING TIMESTAMP 9999 WHERE a = 1 and b = 1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); @@ -108,7 +108,7 @@ public void testTrackMetadata_rowTombstone() throws Throwable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("DELETE FROM %s USING TIMESTAMP 9999 WHERE a = 1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); @@ -128,7 +128,7 @@ public void testTrackMetadata_rowMarker() throws Throwable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); execute("INSERT INTO %s (a) VALUES (1) USING TIMESTAMP 9999"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); @@ -147,7 +147,7 @@ public void testTrackMetadata_rowMarkerDelete() throws Throwable createTable("CREATE TABLE %s (a int, PRIMARY KEY (a))"); ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); 
execute("DELETE FROM %s USING TIMESTAMP 9999 WHERE a=1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); StatsMetadata metadata = cfs.getLiveSSTables().iterator().next().getSSTableMetadata(); assertEquals(9999, metadata.minTimestamp); diff --git a/test/unit/org/apache/cassandra/db/CleanupTest.java b/test/unit/org/apache/cassandra/db/CleanupTest.java index 46c0afd938ca..a91a723dd24e 100644 --- a/test/unit/org/apache/cassandra/db/CleanupTest.java +++ b/test/unit/org/apache/cassandra/db/CleanupTest.java @@ -27,10 +27,12 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import com.google.common.collect.Sets; import org.junit.BeforeClass; import org.junit.Test; @@ -68,6 +70,9 @@ public class CleanupTest public static final String CF_INDEXED2 = "Indexed2"; public static final String CF_STANDARD2 = "Standard2"; + public static final String KEYSPACE3 = "CleanupSkipSSTables"; + public static final String CF_STANDARD3 = "Standard3"; + public static final ByteBuffer COLUMN = ByteBufferUtil.bytes("birthdate"); public static final ByteBuffer VALUE = ByteBuffer.allocate(8); static @@ -105,6 +110,9 @@ public String getDatacenter(InetAddressAndPort endpoint) KeyspaceParams.nts("DC1", 1), SchemaLoader.standardCFMD(KEYSPACE2, CF_STANDARD2), SchemaLoader.compositeIndexCFMD(KEYSPACE2, CF_INDEXED2, true)); + SchemaLoader.createKeyspace(KEYSPACE3, + KeyspaceParams.nts("DC1", 1), + SchemaLoader.standardCFMD(KEYSPACE3, CF_STANDARD3)); } @Test @@ -247,6 +255,43 @@ private void testCleanupWithNoTokenRange(boolean isUserDefined) throws Exception assertTrue(cfs.getLiveSSTables().isEmpty()); } + @Test + public void testCleanupSkippingSSTables() throws UnknownHostException, ExecutionException, InterruptedException + { + Keyspace keyspace = Keyspace.open(KEYSPACE3); + 
ColumnFamilyStore cfs = keyspace.getColumnFamilyStore(CF_STANDARD3); + cfs.disableAutoCompaction(); + TokenMetadata tmd = StorageService.instance.getTokenMetadata(); + tmd.clearUnsafe(); + tmd.updateNormalToken(token(new byte[]{ 50 }), InetAddressAndPort.getByName("127.0.0.1")); + + for (byte i = 0; i < 100; i++) + { + new RowUpdateBuilder(cfs.metadata(), System.currentTimeMillis(), ByteBuffer.wrap(new byte[]{ i })) + .clustering(COLUMN) + .add("val", VALUE) + .build() + .applyUnsafe(); + cfs.forceBlockingFlushToSSTable(); + } + + Set beforeFirstCleanup = Sets.newHashSet(cfs.getLiveSSTables()); + // single token - 127.0.0.1 owns everything, cleanup should be noop + cfs.forceCleanup(2); + assertEquals(beforeFirstCleanup, cfs.getLiveSSTables()); + tmd.updateNormalToken(token(new byte[]{ 120 }), InetAddressAndPort.getByName("127.0.0.2")); + + cfs.forceCleanup(2); + for (SSTableReader sstable : cfs.getLiveSSTables()) + { + assertEquals(sstable.first, sstable.last); // single-token sstables + assertTrue(sstable.first.getToken().compareTo(token(new byte[]{ 50 })) <= 0); + // with single-token sstables they should all either be skipped or dropped: + assertTrue(beforeFirstCleanup.contains(sstable)); + } + + } + @Test public void testuserDefinedCleanupWithNewToken() throws ExecutionException, InterruptedException, UnknownHostException @@ -362,7 +407,7 @@ protected void fillCF(ColumnFamilyStore cfs, String colName, int rowsPerSSTable) .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } protected List getMaxTimestampList(ColumnFamilyStore cfs) diff --git a/test/unit/org/apache/cassandra/db/CleanupTransientTest.java b/test/unit/org/apache/cassandra/db/CleanupTransientTest.java index 9789183dc14b..dbce4e1be61e 100644 --- a/test/unit/org/apache/cassandra/db/CleanupTransientTest.java +++ b/test/unit/org/apache/cassandra/db/CleanupTransientTest.java @@ -182,7 +182,7 @@ protected void fillCF(ColumnFamilyStore cfs, String colName, int rowsPerSSTable) 
.applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } protected List getMaxTimestampList(ColumnFamilyStore cfs) diff --git a/test/unit/org/apache/cassandra/db/ColumnFamilyMetricTest.java b/test/unit/org/apache/cassandra/db/ColumnFamilyMetricTest.java index c016f9ba6446..aeaad46d2d79 100644 --- a/test/unit/org/apache/cassandra/db/ColumnFamilyMetricTest.java +++ b/test/unit/org/apache/cassandra/db/ColumnFamilyMetricTest.java @@ -61,7 +61,7 @@ public void testSizeMetric() .build() .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstables = cfs.getLiveSSTables(); long size = 0; for (SSTableReader reader : sstables) diff --git a/test/unit/org/apache/cassandra/db/ColumnFamilyStoreTest.java b/test/unit/org/apache/cassandra/db/ColumnFamilyStoreTest.java index 888cdc6c8e75..1b8f98646e37 100644 --- a/test/unit/org/apache/cassandra/db/ColumnFamilyStoreTest.java +++ b/test/unit/org/apache/cassandra/db/ColumnFamilyStoreTest.java @@ -38,7 +38,6 @@ import org.apache.cassandra.db.lifecycle.SSTableSet; import org.apache.cassandra.db.rows.*; import org.apache.cassandra.db.partitions.*; -import org.apache.cassandra.db.marshal.*; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.io.sstable.Component; import org.apache.cassandra.io.sstable.Descriptor; @@ -101,14 +100,14 @@ public void testTimeSortedQuery() .add("val", "asdf") .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, "key1") .clustering("Column1") .add("val", "asdf") .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ((ClearableHistogram)cfs.metric.sstablesPerReadHistogram.cf).clear(); // resets counts Util.getAll(Util.cmd(cfs, "key1").includeRow("c1").build()); @@ -177,7 +176,7 @@ public void testDeleteStandardRowSticksAfterFlush() throws Throwable assertRangeCount(cfs, col, val, 2); // flush. 
- cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // insert, don't flush new RowUpdateBuilder(cfs.metadata(), 1, "key3").clustering("Column1").add("val", "val1").build().applyUnsafe(); @@ -192,7 +191,7 @@ public void testDeleteStandardRowSticksAfterFlush() throws Throwable assertRangeCount(cfs, col, val, 2); // flush - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // re-verify delete. // first breakage is right here because of CASSANDRA-1837. assertRangeCount(cfs, col, val, 2); @@ -210,7 +209,7 @@ public void testDeleteStandardRowSticksAfterFlush() throws Throwable assertRangeCount(cfs, col, val, 4); // and it remains so after flush. (this wasn't failing before, but it's good to check.) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertRangeCount(cfs, col, val, 4); } @@ -259,9 +258,9 @@ public void testBackupAfterFlush() throws Throwable { ColumnFamilyStore cfs = Keyspace.open(KEYSPACE2).getColumnFamilyStore(CF_STANDARD1); new RowUpdateBuilder(cfs.metadata(), 0, ByteBufferUtil.bytes("key1")).clustering("Column1").add("val", "asdf").build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, ByteBufferUtil.bytes("key2")).clustering("Column1").add("val", "asdf").build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int version = 1; version <= 2; ++version) { @@ -400,7 +399,7 @@ public void testBackupAfterFlush() throws Throwable public void reTest(ColumnFamilyStore cfs, Runnable verify) throws Exception { verify.run(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); verify.run(); } @@ -435,7 +434,7 @@ public void testScrubDataDirectories() throws Throwable ColumnFamilyStore.scrubDataDirectories(cfs.metadata()); new RowUpdateBuilder(cfs.metadata(), 2, "key").clustering("name").add("val", "2").build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Nuke the 
metadata and reload that sstable Collection ssTables = cfs.getLiveSSTables(); diff --git a/test/unit/org/apache/cassandra/db/DeletePartitionTest.java b/test/unit/org/apache/cassandra/db/DeletePartitionTest.java index 6ed43f726250..f8b31470b6eb 100644 --- a/test/unit/org/apache/cassandra/db/DeletePartitionTest.java +++ b/test/unit/org/apache/cassandra/db/DeletePartitionTest.java @@ -75,7 +75,7 @@ public void testDeletePartition(DecoratedKey key, boolean flushBeforeRemove, boo assertTrue(r.getCell(column).value().equals(ByteBufferUtil.bytes("asdf"))); if (flushBeforeRemove) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); // delete the partition new Mutation.PartitionUpdateCollector(KEYSPACE1, key) @@ -84,7 +84,7 @@ public void testDeletePartition(DecoratedKey key, boolean flushBeforeRemove, boo .applyUnsafe(); if (flushAfterRemove) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); // validate removal ImmutableBTreePartition partitionUnfiltered = Util.getOnlyPartitionUnfiltered(Util.cmd(store, key).build()); diff --git a/test/unit/org/apache/cassandra/db/ImportTest.java b/test/unit/org/apache/cassandra/db/ImportTest.java index 5ceb233a56e4..a1d04ab30a8b 100644 --- a/test/unit/org/apache/cassandra/db/ImportTest.java +++ b/test/unit/org/apache/cassandra/db/ImportTest.java @@ -23,18 +23,14 @@ import java.io.RandomAccessFile; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Random; import java.util.Set; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; import org.junit.Test; @@ -43,8 +39,6 @@ import org.apache.cassandra.cql3.UntypedResultSet; import org.apache.cassandra.db.lifecycle.LifecycleTransaction; import 
org.apache.cassandra.dht.BootStrapper; -import org.apache.cassandra.dht.Murmur3Partitioner; -import org.apache.cassandra.dht.Token; import org.apache.cassandra.io.sstable.Component; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.locator.InetAddressAndPort; @@ -66,7 +60,7 @@ public void basicImportTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -87,14 +81,14 @@ public void basicImportMultiDirTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); File backupdir = moveToBackupDir(sstables); for (int i = 10; i < 20; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -118,7 +112,7 @@ public void refreshTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); sstables.forEach(s -> 
s.selfRef().release()); @@ -133,7 +127,7 @@ public void importResetLevelTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); for (SSTableReader sstable : sstables) @@ -170,7 +164,7 @@ public void importClearRepairedTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); for (SSTableReader sstable : sstables) @@ -252,7 +246,7 @@ public void testGetCorrectDirectory() throws Throwable for (int i = 0; i < 10; i++) { execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } Set toMove = getCurrentColumnFamilyStore().getLiveSSTables(); @@ -281,11 +275,11 @@ private void testCorruptHelper(boolean verify) throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); SSTableReader sstableToCorrupt = getCurrentColumnFamilyStore().getLiveSSTables().iterator().next(); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i + 10, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = 
getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -302,7 +296,7 @@ private void testCorruptHelper(boolean verify) throws Throwable // now move a correct sstable to another directory to make sure that directory gets properly imported for (int i = 100; i < 130; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set correctSSTables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -359,7 +353,7 @@ public void testImportOutOfRange() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 1000; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -404,7 +398,7 @@ public void testImportOutOfRangeExtendedVerify() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 1000; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); @@ -440,7 +434,7 @@ public void testImportInvalidateCache() throws Throwable createTable("create table %s (id int primary key, d int) WITH caching = { 'keys': 'NONE', 'rows_per_partition': 'ALL' }"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); CacheService.instance.setRowCacheCapacityInMB(1); Set keysToInvalidate = new HashSet<>(); @@ 
-461,7 +455,7 @@ public void testImportInvalidateCache() throws Throwable for (int i = 10; i < 20; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set allCachedKeys = new HashSet<>(); @@ -508,7 +502,7 @@ public void testImportCacheEnabledWithoutSrcDir() throws Throwable createTable("create table %s (id int primary key, d int) WITH caching = { 'keys': 'NONE', 'rows_per_partition': 'ALL' }"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); CacheService.instance.setRowCacheCapacityInMB(1); getCurrentColumnFamilyStore().clearUnsafe(); @@ -525,7 +519,7 @@ public void testRefreshCorrupt() throws Throwable createTable("create table %s (id int primary key, d int) WITH caching = { 'keys': 'NONE', 'rows_per_partition': 'ALL' }"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); sstables.forEach(s -> s.selfRef().release()); @@ -540,10 +534,10 @@ public void testRefreshCorrupt() throws Throwable for (int i = 10; i < 20; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); for (int i = 20; i < 30; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set expectedFiles = new HashSet<>(getCurrentColumnFamilyStore().getLiveSSTables()); @@ 
-589,14 +583,14 @@ public void importBadDirectoryTest() throws Throwable createTable("create table %s (id int primary key, d int)"); for (int i = 0; i < 10; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); File backupdir = moveToBackupDir(sstables); for (int i = 10; i < 20; i++) execute("insert into %s (id, d) values (?, ?)", i, i); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); sstables = getCurrentColumnFamilyStore().getLiveSSTables(); getCurrentColumnFamilyStore().clearUnsafe(); diff --git a/test/unit/org/apache/cassandra/db/KeyCacheTest.java b/test/unit/org/apache/cassandra/db/KeyCacheTest.java index 1819b1811844..c1dab8c2ed74 100644 --- a/test/unit/org/apache/cassandra/db/KeyCacheTest.java +++ b/test/unit/org/apache/cassandra/db/KeyCacheTest.java @@ -102,7 +102,7 @@ private void testKeyCacheLoad(String cf) throws Exception // insert data and force to disk SchemaLoader.insertData(KEYSPACE1, cf, 0, 100); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); // populate the cache readData(KEYSPACE1, cf, 0, 100); @@ -202,7 +202,7 @@ private void testKeyCacheLoadWithLostTable(String cf) throws Exception // insert data and force to disk SchemaLoader.insertData(KEYSPACE1, cf, 0, 100); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Collection firstFlushTables = ImmutableList.copyOf(store.getLiveSSTables()); @@ -212,7 +212,7 @@ private void testKeyCacheLoadWithLostTable(String cf) throws Exception // insert some new data and force to disk SchemaLoader.insertData(KEYSPACE1, cf, 100, 50); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); // check that it's fine readData(KEYSPACE1, cf, 100, 50); @@ -273,7 +273,7 @@ 
private void testKeyCache(String cf) throws ExecutionException, InterruptedExcep new RowUpdateBuilder(cfs.metadata(), 0, "key2").clustering("2").build().applyUnsafe(); // to make sure we have SSTable - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // reads to cache key position Util.getAll(Util.cmd(cfs, "key1").build()); diff --git a/test/unit/org/apache/cassandra/db/KeyspaceTest.java b/test/unit/org/apache/cassandra/db/KeyspaceTest.java index 3e088fbf6d36..df13edfe6ffb 100644 --- a/test/unit/org/apache/cassandra/db/KeyspaceTest.java +++ b/test/unit/org/apache/cassandra/db/KeyspaceTest.java @@ -83,7 +83,7 @@ public void testGetRowNoColumns() throws Throwable Util.assertEmpty(Util.cmd(cfs, "0").columns("c").includeRow(1).build()); if (round == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -118,7 +118,7 @@ public void testGetRowSingleColumn() throws Throwable } if (round == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -135,7 +135,7 @@ public void testGetSliceBloomFilterFalsePositive() throws Throwable for (String key : new String[]{"0", "2"}) Util.assertEmpty(Util.cmd(cfs, key).build()); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (String key : new String[]{"0", "2"}) Util.assertEmpty(Util.cmd(cfs, key).build()); @@ -207,7 +207,7 @@ public void testGetSliceWithCutoff() throws Throwable assertRowsInSlice(cfs, "0", 288, 299, 12, true, prefix); if (round == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -220,7 +220,7 @@ public void testReversedWithFlushing() throws Throwable for (int i = 0; i < 10; i++) execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?)", "0", i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int i = 10; i < 20; i++) { @@ -334,7 +334,7 @@ public void testGetSliceFromBasic() throws Throwable assertRowsInResult(cfs, command); if (round == 0) - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); } } @@ -357,7 +357,7 @@ public void testGetSliceWithExpiration() throws Throwable assertRowsInResult(cfs, command, 1); if (round == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -370,7 +370,7 @@ public void testGetSliceFromAdvanced() throws Throwable for (int i = 1; i < 7; i++) execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?)", "0", i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // overwrite three rows with -1 for (int i = 1; i < 4; i++) @@ -382,7 +382,7 @@ public void testGetSliceFromAdvanced() throws Throwable assertRowsInResult(cfs, command, -1, -1, 4); if (round == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -395,7 +395,7 @@ public void testGetSliceFromLarge() throws Throwable for (int i = 1000; i < 2000; i++) execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?)", "0", i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); validateSliceLarge(cfs); @@ -423,7 +423,7 @@ public void testLimitSSTables() throws Throwable for (int i = 1000 + (j*100); i < 1000 + ((j+1)*100); i++) execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?) 
USING TIMESTAMP ?", "0", i, i, (long)i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } ((ClearableHistogram)cfs.metric.sstablesPerReadHistogram.cf).clear(); diff --git a/test/unit/org/apache/cassandra/db/NameSortTest.java b/test/unit/org/apache/cassandra/db/NameSortTest.java index 0b00f40ea2b1..517489c411d3 100644 --- a/test/unit/org/apache/cassandra/db/NameSortTest.java +++ b/test/unit/org/apache/cassandra/db/NameSortTest.java @@ -84,7 +84,7 @@ private void testNameSort(int N) throws IOException rub.build().applyUnsafe(); } validateNameSort(cfs); - keyspace.getColumnFamilyStore("Standard1").forceBlockingFlush(); + keyspace.getColumnFamilyStore("Standard1").forceBlockingFlushToSSTable(); validateNameSort(cfs); } diff --git a/test/unit/org/apache/cassandra/db/PartitionRangeReadTest.java b/test/unit/org/apache/cassandra/db/PartitionRangeReadTest.java index 9ae6c757b018..5d56ab4bd693 100644 --- a/test/unit/org/apache/cassandra/db/PartitionRangeReadTest.java +++ b/test/unit/org/apache/cassandra/db/PartitionRangeReadTest.java @@ -99,14 +99,14 @@ public void testCassandra6778() throws CharacterCodingException .add("val", "val1") .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, "k1") .clustering(new BigInteger(new byte[]{0, 0, 1})) .add("val", "val2") .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // fetch by the first column name; we should get the second version of the column value Row row = Util.getOnlyRow(Util.cmd(cfs, "k1").includeRow(new BigInteger(new byte[]{1})).build()); @@ -158,7 +158,7 @@ public void testRangeSliceInclusionExclusion() throws Throwable builder.build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ColumnMetadata cDef = cfs.metadata().getColumn(ByteBufferUtil.bytes("val")); diff --git a/test/unit/org/apache/cassandra/db/RangeTombstoneTest.java 
b/test/unit/org/apache/cassandra/db/RangeTombstoneTest.java index 3d1d00322d78..31eb28b57d5b 100644 --- a/test/unit/org/apache/cassandra/db/RangeTombstoneTest.java +++ b/test/unit/org/apache/cassandra/db/RangeTombstoneTest.java @@ -85,7 +85,7 @@ public void simpleQueryWithRangeTombstoneTest() throws Exception for (int i = 0; i < 40; i += 2) builder.newRow(i).add("val", i); builder.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(10, 22).build().applyUnsafe(); @@ -235,7 +235,7 @@ public void testTrackTimesPartitionTombstone() throws ExecutionException, Interr int nowInSec = FBUtilities.nowInSeconds(); new Mutation(PartitionUpdate.fullPartitionDelete(cfs.metadata(), Util.dk(key), 1000, nowInSec)).apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); assertTimes(sstable.getSSTableMetadata(), 1000, 1000, nowInSec); @@ -257,7 +257,7 @@ public void testTrackTimesPartitionTombstoneWithData() throws ExecutionException key = "rt_times2"; int nowInSec = FBUtilities.nowInSeconds(); new Mutation(PartitionUpdate.fullPartitionDelete(cfs.metadata(), Util.dk(key), 1000, nowInSec)).apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); assertTimes(sstable.getSSTableMetadata(), 999, 1000, Integer.MAX_VALUE); @@ -276,7 +276,7 @@ public void testTrackTimesRangeTombstone() throws ExecutionException, Interrupte int nowInSec = FBUtilities.nowInSeconds(); new RowUpdateBuilder(cfs.metadata(), nowInSec, 1000L, key).addRangeTombstone(1, 2).build().apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); assertTimes(sstable.getSSTableMetadata(), 1000, 1000, nowInSec); @@ -298,9 +298,9 @@ public void testTrackTimesRangeTombstoneWithData() throws 
ExecutionException, In key = "rt_times2"; int nowInSec = FBUtilities.nowInSeconds(); new Mutation(PartitionUpdate.fullPartitionDelete(cfs.metadata(), Util.dk(key), 1000, nowInSec)).apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); assertTimes(sstable.getSSTableMetadata(), 999, 1000, Integer.MAX_VALUE); cfs.forceMajorCompaction(); @@ -328,10 +328,10 @@ public void test7810() throws ExecutionException, InterruptedException for (int i = 10; i < 20; i ++) builder.newRow(i).add("val", i); builder.apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(10, 11).build().apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Thread.sleep(5); cfs.forceMajorCompaction(); @@ -350,10 +350,10 @@ public void test7808_1() throws ExecutionException, InterruptedException for (int i = 0; i < 40; i += 2) builder.newRow(i).add("val", i); builder.apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new Mutation(PartitionUpdate.fullPartitionDelete(cfs.metadata(), Util.dk(key), 1, 1)).apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Thread.sleep(5); cfs.forceMajorCompaction(); } @@ -370,13 +370,13 @@ public void test7808_2() throws ExecutionException, InterruptedException for (int i = 10; i < 20; i ++) builder.newRow(i).add("val", i); builder.apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new Mutation(PartitionUpdate.fullPartitionDelete(cfs.metadata(), Util.dk(key), 0, 0)).apply(); UpdateBuilder.create(cfs.metadata(), key).withTimestamp(1).newRow(5).add("val", 5).apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Thread.sleep(5); cfs.forceMajorCompaction(); assertEquals(1, Util.getOnlyPartitionUnfiltered(Util.cmd(cfs, key).build()).rowCount()); 
@@ -396,16 +396,16 @@ public void overlappingRangeTest() throws Exception for (int i = 0; i < 20; i++) builder.newRow(i).add("val", i); builder.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(5, 15).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(5, 10).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 2, key).addRangeTombstone(5, 8).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Partition partition = Util.getOnlyPartitionUnfiltered(Util.cmd(cfs, key).build()); int nowInSec = FBUtilities.nowInSeconds(); @@ -447,11 +447,11 @@ public void reverseQueryTest() throws Exception String key = "k3"; UpdateBuilder.create(cfs.metadata(), key).withTimestamp(0).newRow(2).add("val", 2).applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(0, 10).build().applyUnsafe(); UpdateBuilder.create(cfs.metadata(), key).withTimestamp(2).newRow(1).add("val", 1).applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Get the last value of the row FilteredPartition partition = Util.getOnlyPartition(Util.cmd(cfs, key).build()); @@ -508,10 +508,10 @@ public void testRowWithRangeTombstonesUpdatesSecondaryIndex() throws Exception for (int i = 0; i < 10; i++) builder.newRow(i).add("val", i); builder.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, key).addRangeTombstone(0, 7).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(10, index.rowsInserted.size()); @@ -538,10 +538,10 @@ public void testRangeTombstoneCompaction() throws Exception for (int i = 0; i < 
10; i += 2) builder.newRow(i).add("val", i); builder.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, key).addRangeTombstone(0, 7).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // there should be 2 sstables assertEquals(2, cfs.getLiveSSTables().size()); @@ -614,7 +614,7 @@ public void testOverwritesToDeletedColumns() throws Exception // now re-insert that column UpdateBuilder.create(cfs.metadata(), key).withTimestamp(2).newRow(1).add("val", 1).applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // We should have 1 insert and 1 update to the indexed "1" column // CASSANDRA-6640 changed index update to just update, not insert then delete diff --git a/test/unit/org/apache/cassandra/db/ReadCommandTest.java b/test/unit/org/apache/cassandra/db/ReadCommandTest.java index 3bb30d961cb4..fd31b9b1c243 100644 --- a/test/unit/org/apache/cassandra/db/ReadCommandTest.java +++ b/test/unit/org/apache/cassandra/db/ReadCommandTest.java @@ -55,8 +55,9 @@ import org.apache.cassandra.locator.EndpointsForToken; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.repair.consistent.LocalSessionAccessor; import org.apache.cassandra.schema.CachingParams; import org.apache.cassandra.schema.KeyspaceParams; @@ -174,7 +175,7 @@ public void testPartitionRangeAbort() throws Exception .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, ByteBufferUtil.bytes("key2")) .clustering("Column1") @@ -202,7 +203,7 @@ public void testSinglePartitionSliceAbort() throws Exception .build() .apply(); - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, ByteBufferUtil.bytes("key")) .clustering("dd") @@ -233,7 +234,7 @@ public void testSinglePartitionNamesAbort() throws Exception .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 0, ByteBufferUtil.bytes("key")) .clustering("dd") @@ -312,7 +313,7 @@ public void testSinglePartitionGroupMerge() throws Exception commands.add(SinglePartitionReadCommand.create(cfs.metadata(), nowInSeconds, columnFilter, rowFilter, DataLimits.NONE, Util.dk(data[1]), sliceFilter)); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ReadQuery query = new SinglePartitionReadCommand.Group(commands, DataLimits.NONE); @@ -396,9 +397,9 @@ public void testSerializer() throws IOException int messagingVersion = MessagingService.current_version; FakeOutputStream out = new FakeOutputStream(); Tracing.instance.newSession(Tracing.TraceType.QUERY); - MessageOut messageOut = new MessageOut(MessagingService.Verb.READ, readCommand, ReadCommand.serializer); + Message messageOut = Message.out(Verb.READ_REQ, readCommand); long size = messageOut.serializedSize(messagingVersion); - messageOut.serialize(new WrappedDataOutputStreamPlus(out), messagingVersion); + Message.serializer.serialize(messageOut, new WrappedDataOutputStreamPlus(out), messagingVersion); Assert.assertEquals(size, out.count); } @@ -482,7 +483,7 @@ public void testCountDeletedRows() throws Exception DataLimits.NONE, Util.dk(data[1]), sliceFilter)); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ReadQuery query = new SinglePartitionReadCommand.Group(commands, DataLimits.NONE); @@ -558,7 +559,7 @@ public void testCountWithNoDeletedRow() throws Exception DataLimits.NONE, Util.dk(data[1]), sliceFilter)); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ReadQuery query = new SinglePartitionReadCommand.Group(commands, DataLimits.NONE); @@ -617,7 +618,7 @@ 
public void testSinglePartitionNamesSkipsOptimisationsIfTrackingRepairedData() .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, ByteBufferUtil.bytes("key")) .clustering("dd") @@ -625,7 +626,7 @@ public void testSinglePartitionNamesSkipsOptimisationsIfTrackingRepairedData() .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List sstables = new ArrayList<>(cfs.getLiveSSTables()); assertEquals(2, sstables.size()); Collections.sort(sstables, SSTableReader.maxTimestampDescending); @@ -665,7 +666,7 @@ public void skipRowCacheIfTrackingRepairedData() .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ReadCommand readCommand = Util.cmd(cfs, Util.dk("key")).build(); assertTrue(cfs.isRowCacheEnabled()); @@ -738,7 +739,7 @@ private void testRepairedDataTracking(ColumnFamilyStore cfs, ReadCommand readCom .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 1, ByteBufferUtil.bytes("key")) .clustering("dd") @@ -746,7 +747,7 @@ private void testRepairedDataTracking(ColumnFamilyStore cfs, ReadCommand readCom .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List sstables = new ArrayList<>(cfs.getLiveSSTables()); assertEquals(2, sstables.size()); sstables.forEach(sstable -> assertFalse(sstable.isRepaired() || sstable.isPendingRepair())); @@ -803,7 +804,7 @@ private void testRepairedDataTracking(ColumnFamilyStore cfs, ReadCommand readCom assertEquals(ByteBufferUtil.EMPTY_BYTE_BUFFER, digest); // now flush so we have an unrepaired table with the deletion and repeat the check - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); digest = performReadAndVerifyRepairedInfo(readCommand, 0, rowsPerPartition, false); assertEquals(ByteBufferUtil.EMPTY_BYTE_BUFFER, digest); } diff --git 
a/test/unit/org/apache/cassandra/db/ReadCommandVerbHandlerTest.java b/test/unit/org/apache/cassandra/db/ReadCommandVerbHandlerTest.java index b7e053b1db22..868227390427 100644 --- a/test/unit/org/apache/cassandra/db/ReadCommandVerbHandlerTest.java +++ b/test/unit/org/apache/cassandra/db/ReadCommandVerbHandlerTest.java @@ -19,17 +19,14 @@ package org.apache.cassandra.db; import java.net.UnknownHostException; -import java.util.Map; import java.util.Random; import java.util.UUID; -import com.google.common.collect.ImmutableMap; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.apache.cassandra.SchemaLoader; -import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.filter.ClusteringIndexSliceFilter; import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.filter.DataLimits; @@ -37,18 +34,17 @@ import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.TokenMetadata; -import org.apache.cassandra.net.IMessageSink; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; +import org.apache.cassandra.net.ParamType; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; -import static org.apache.cassandra.Util.token; +import static org.apache.cassandra.net.Verb.*; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -84,19 +80,10 @@ public static void init() throws Throwable @Before public void setup() { - MessagingService.instance().clearMessageSinks(); 
- MessagingService.instance().addMessageSink(new IMessageSink() - { - public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to) - { - return false; - } - - public boolean allowIncomingMessage(MessageIn message, int id) - { - return false; - } - }); + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + MessagingService.instance().outboundSink.add((message, to) -> false); + MessagingService.instance().inboundSink.add((message) -> false); handler = new ReadCommandVerbHandler(); } @@ -104,59 +91,50 @@ public boolean allowIncomingMessage(MessageIn message, int id) @Test public void setRepairedDataTrackingFlagIfHeaderPresent() { - SinglePartitionReadCommand command = command(metadata); + ReadCommand command = command(metadata); assertFalse(command.isTrackingRepairedStatus()); - Map params = ImmutableMap.of(ParameterType.TRACK_REPAIRED_DATA, - MessagingService.ONE_BYTE); - handler.doVerb(MessageIn.create(peer(), - command, - params, - MessagingService.Verb.READ, - MessagingService.current_version), - messageId()); + + handler.doVerb(Message.builder(READ_REQ, command) + .from(peer()) + .withFlag(MessageFlag.TRACK_REPAIRED_DATA) + .withId(messageId()) + .build()); assertTrue(command.isTrackingRepairedStatus()); } @Test public void dontSetRepairedDataTrackingFlagUnlessHeaderPresent() { - SinglePartitionReadCommand command = command(metadata); + ReadCommand command = command(metadata); assertFalse(command.isTrackingRepairedStatus()); - Map params = ImmutableMap.of(ParameterType.TRACE_SESSION, - UUID.randomUUID()); - handler.doVerb(MessageIn.create(peer(), - command, - params, - MessagingService.Verb.READ, - MessagingService.current_version), - messageId()); + handler.doVerb(Message.builder(READ_REQ, command) + .from(peer()) + .withId(messageId()) + .withParam(ParamType.TRACE_SESSION, UUID.randomUUID()) + .build()); assertFalse(command.isTrackingRepairedStatus()); } @Test public void 
dontSetRepairedDataTrackingFlagIfHeadersEmpty() { - SinglePartitionReadCommand command = command(metadata); + ReadCommand command = command(metadata); assertFalse(command.isTrackingRepairedStatus()); - handler.doVerb(MessageIn.create(peer(), - command, - ImmutableMap.of(), - MessagingService.Verb.READ, - MessagingService.current_version), - messageId()); + handler.doVerb(Message.builder(READ_REQ, command) + .withId(messageId()) + .from(peer()) + .build()); assertFalse(command.isTrackingRepairedStatus()); } @Test (expected = InvalidRequestException.class) public void rejectsRequestWithNonMatchingTransientness() { - SinglePartitionReadCommand command = command(metadata_with_transient); - handler.doVerb(MessageIn.create(peer(), - command, - ImmutableMap.of(), - MessagingService.Verb.READ, - MessagingService.current_version), - messageId()); + ReadCommand command = command(metadata_with_transient); + handler.doVerb(Message.builder(READ_REQ, command) + .from(peer()) + .withId(messageId()) + .build()); } private static int messageId() diff --git a/test/unit/org/apache/cassandra/db/RecoveryManagerFlushedTest.java b/test/unit/org/apache/cassandra/db/RecoveryManagerFlushedTest.java index fc3494234021..f373a00e0ef0 100644 --- a/test/unit/org/apache/cassandra/db/RecoveryManagerFlushedTest.java +++ b/test/unit/org/apache/cassandra/db/RecoveryManagerFlushedTest.java @@ -115,7 +115,7 @@ public void testWithFlush() throws Exception Keyspace keyspace1 = Keyspace.open(KEYSPACE1); ColumnFamilyStore cfs = keyspace1.getColumnFamilyStore("Standard1"); logger.debug("forcing flush"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); logger.debug("begin manual replay"); // replay the commit log (nothing on Standard1 should be replayed since everything was flushed, so only the row on Standard2 diff --git a/test/unit/org/apache/cassandra/db/RemoveCellTest.java b/test/unit/org/apache/cassandra/db/RemoveCellTest.java index 01fe2551f4b9..03381774e1cf 100644 --- 
a/test/unit/org/apache/cassandra/db/RemoveCellTest.java +++ b/test/unit/org/apache/cassandra/db/RemoveCellTest.java @@ -30,7 +30,7 @@ public void testDeleteCell() throws Throwable String tableName = createTable("CREATE TABLE %s (a int, b int, c int, PRIMARY KEY (a, b))"); ColumnFamilyStore cfs = Keyspace.open(KEYSPACE).getColumnFamilyStore(tableName); execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?) USING TIMESTAMP ?", 0, 0, 0, 0L); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); execute("DELETE c FROM %s USING TIMESTAMP ? WHERE a = ? AND b = ?", 1L, 0, 0); assertRows(execute("SELECT * FROM %s WHERE a = ? AND b = ?", 0, 0), row(0, 0, null)); assertRows(execute("SELECT c FROM %s WHERE a = ? AND b = ?", 0, 0), row(new Object[]{null})); diff --git a/test/unit/org/apache/cassandra/db/RowCacheTest.java b/test/unit/org/apache/cassandra/db/RowCacheTest.java index 5ca1eefb854e..b5440c2ac0a9 100644 --- a/test/unit/org/apache/cassandra/db/RowCacheTest.java +++ b/test/unit/org/apache/cassandra/db/RowCacheTest.java @@ -490,7 +490,7 @@ public void testSSTablesPerReadHistogramWhenRowCache() SchemaLoader.insertData(KEYSPACE_CACHED, CF_CACHED, 0, 100); //force flush for confidence that SSTables exists - cachedStore.forceBlockingFlush(); + cachedStore.forceBlockingFlushToSSTable(); ((ClearableHistogram)cachedStore.metric.sstablesPerReadHistogram.cf).clear(); diff --git a/test/unit/org/apache/cassandra/db/RowIterationTest.java b/test/unit/org/apache/cassandra/db/RowIterationTest.java index b0cd4fc1ca40..0e229aa2ed69 100644 --- a/test/unit/org/apache/cassandra/db/RowIterationTest.java +++ b/test/unit/org/apache/cassandra/db/RowIterationTest.java @@ -36,7 +36,7 @@ public void testRowIteration() throws Throwable ColumnFamilyStore cfs = Keyspace.open(KEYSPACE).getColumnFamilyStore(tableName); for (int i = 0; i < 10; i++) execute("INSERT INTO %s (a, b, c, d) VALUES (?, ?, ?, ?) 
USING TIMESTAMP ?", i, 0, i, i, (long)i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(10, execute("SELECT * FROM %s").size()); } @@ -49,7 +49,7 @@ public void testRowIterationDeletionTime() throws Throwable execute("INSERT INTO %s (a, b) VALUES (?, ?) USING TIMESTAMP ?", 0, 0, 0L); execute("DELETE FROM %s USING TIMESTAMP ? WHERE a = ?", 0L, 0); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Delete row in second sstable with higher timestamp execute("INSERT INTO %s (a, b) VALUES (?, ?) USING TIMESTAMP ?", 0, 0, 1L); @@ -57,7 +57,7 @@ public void testRowIterationDeletionTime() throws Throwable int localDeletionTime = Util.getOnlyPartitionUnfiltered(Util.cmd(cfs).build()).partitionLevelDeletion().localDeletionTime(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); DeletionTime dt = Util.getOnlyPartitionUnfiltered(Util.cmd(cfs).build()).partitionLevelDeletion(); assertEquals(1L, dt.markedForDeleteAt()); @@ -72,7 +72,7 @@ public void testRowIterationDeletion() throws Throwable // Delete a row in first sstable execute("DELETE FROM %s USING TIMESTAMP ? WHERE a = ?", 0L, 0); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertFalse(Util.getOnlyPartitionUnfiltered(Util.cmd(cfs).build()).isEmpty()); } diff --git a/test/unit/org/apache/cassandra/db/ScrubTest.java b/test/unit/org/apache/cassandra/db/ScrubTest.java index 28962dba3053..ffc1bbd17c43 100644 --- a/test/unit/org/apache/cassandra/db/ScrubTest.java +++ b/test/unit/org/apache/cassandra/db/ScrubTest.java @@ -460,7 +460,7 @@ protected void fillCF(ColumnFamilyStore cfs, int partitionsPerSSTable) new Mutation(update).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } public static void fillIndexCF(ColumnFamilyStore cfs, boolean composite, long ... values) @@ -484,7 +484,7 @@ public static void fillIndexCF(ColumnFamilyStore cfs, boolean composite, long .. 
new Mutation(builder.build()).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } protected void fillCounterCF(ColumnFamilyStore cfs, int partitionsPerSSTable) throws WriteTimeoutException @@ -497,7 +497,7 @@ protected void fillCounterCF(ColumnFamilyStore cfs, int partitionsPerSSTable) th new CounterMutation(new Mutation(update), ConsistencyLevel.ONE).apply(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } @Test @@ -509,14 +509,14 @@ public void testScrubColumnValidation() throws InterruptedException, RequestExec ColumnFamilyStore cfs = keyspace.getColumnFamilyStore("test_compact_static_columns"); QueryProcessor.executeInternal(String.format("INSERT INTO \"%s\".test_compact_static_columns (a, b, c, d) VALUES (123, c3db07e8-b602-11e3-bc6b-e0b9a54a6d93, true, 'foobar')", KEYSPACE)); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); CompactionManager.instance.performScrub(cfs, false, true, 2); QueryProcessor.process("CREATE TABLE \"Keyspace1\".test_scrub_validation (a text primary key, b int)", ConsistencyLevel.ONE); ColumnFamilyStore cfs2 = keyspace.getColumnFamilyStore("test_scrub_validation"); new Mutation(UpdateBuilder.create(cfs2.metadata(), "key").newRow().add("b", LongType.instance.decompose(1L)).build()).apply(); - cfs2.forceBlockingFlush(); + cfs2.forceBlockingFlushToSSTable(); CompactionManager.instance.performScrub(cfs2, false, false, 2); } diff --git a/test/unit/org/apache/cassandra/db/SecondaryIndexTest.java b/test/unit/org/apache/cassandra/db/SecondaryIndexTest.java index 5dadc115f0f2..62aac898f7b8 100644 --- a/test/unit/org/apache/cassandra/db/SecondaryIndexTest.java +++ b/test/unit/org/apache/cassandra/db/SecondaryIndexTest.java @@ -56,6 +56,7 @@ public class SecondaryIndexTest { public static final String KEYSPACE1 = "SecondaryIndexTest1"; public static final String WITH_COMPOSITE_INDEX = "WithCompositeIndex"; + public static final String WITH_MULTIPLE_COMPOSITE_INDEX = 
"WithMultipleCompositeIndex"; public static final String WITH_KEYS_INDEX = "WithKeysIndex"; public static final String COMPOSITE_INDEX_TO_BE_ADDED = "CompositeIndexToBeAdded"; @@ -67,6 +68,7 @@ public static void defineSchema() throws ConfigurationException KeyspaceParams.simple(1), SchemaLoader.compositeIndexCFMD(KEYSPACE1, WITH_COMPOSITE_INDEX, true, true).gcGraceSeconds(0), SchemaLoader.compositeIndexCFMD(KEYSPACE1, COMPOSITE_INDEX_TO_BE_ADDED, false).gcGraceSeconds(0), + SchemaLoader.compositeMultipleIndexCFMD(KEYSPACE1, WITH_MULTIPLE_COMPOSITE_INDEX).gcGraceSeconds(0), SchemaLoader.keysIndexCFMD(KEYSPACE1, WITH_KEYS_INDEX, true).gcGraceSeconds(0)); } @@ -75,6 +77,7 @@ public void truncateCFS() { Keyspace.open(KEYSPACE1).getColumnFamilyStore(WITH_COMPOSITE_INDEX).truncateBlocking(); Keyspace.open(KEYSPACE1).getColumnFamilyStore(COMPOSITE_INDEX_TO_BE_ADDED).truncateBlocking(); + Keyspace.open(KEYSPACE1).getColumnFamilyStore(WITH_MULTIPLE_COMPOSITE_INDEX).truncateBlocking(); Keyspace.open(KEYSPACE1).getColumnFamilyStore(WITH_KEYS_INDEX).truncateBlocking(); } @@ -297,7 +300,7 @@ public void testDeleteOfInconsistentValuesInKeysIndex() throws Exception new RowUpdateBuilder(cfs.metadata(), 1, "k1").noRowMarker().add("birthdate", 1L).build().applyUnsafe(); // force a flush, so our index isn't being read from a memtable - keyspace.getColumnFamilyStore(WITH_KEYS_INDEX).forceBlockingFlush(); + keyspace.getColumnFamilyStore(WITH_KEYS_INDEX).forceBlockingFlushToSSTable(); // now apply another update, but force the index update to be skipped keyspace.apply(new RowUpdateBuilder(cfs.metadata(), 2, "k1").noRowMarker().add("birthdate", 2L).build(), @@ -353,7 +356,7 @@ private void runDeleteOfInconsistentValuesFromCompositeIndexTest(boolean isStati assertIndexedOne(cfs, col, 10l); // force a flush and retry the query, so our index isn't being read from a memtable - keyspace.getColumnFamilyStore(cfName).forceBlockingFlush(); + 
keyspace.getColumnFamilyStore(cfName).forceBlockingFlushToSSTable(); assertIndexedOne(cfs, col, 10l); // now apply another update, but force the index update to be skipped @@ -519,10 +522,33 @@ public void testKeysSearcherSimple() throws Exception new RowUpdateBuilder(cfs.metadata(), 0, "k" + i).noRowMarker().add("birthdate", 1l).build().applyUnsafe(); assertIndexedCount(cfs, ByteBufferUtil.bytes("birthdate"), 1l, 10); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertIndexedCount(cfs, ByteBufferUtil.bytes("birthdate"), 1l, 10); } + @Test + public void testSelectivityWithMultipleIndexes() + { + ColumnFamilyStore cfs = Keyspace.open(KEYSPACE1).getColumnFamilyStore(WITH_MULTIPLE_COMPOSITE_INDEX); + + // creates rows such that birthday_index has 1 partition (key = 1L) with 4 rows -- mean row count = 4, and notbirthdate_index has 2 partitions with 2 rows each -- mean row count = 2 + new RowUpdateBuilder(cfs.metadata(), 0, "k1").clustering("c").add("birthdate", 1L).add("notbirthdate", 2L).build().applyUnsafe(); + new RowUpdateBuilder(cfs.metadata(), 0, "k2").clustering("c").add("birthdate", 1L).add("notbirthdate", 2L).build().applyUnsafe(); + new RowUpdateBuilder(cfs.metadata(), 0, "k3").clustering("c").add("birthdate", 1L).add("notbirthdate", 3L).build().applyUnsafe(); + new RowUpdateBuilder(cfs.metadata(), 0, "k4").clustering("c").add("birthdate", 1L).add("notbirthdate", 3L).build().applyUnsafe(); + + cfs.forceBlockingFlushToSSTable(); + ReadCommand rc = Util.cmd(cfs) + .fromKeyIncl("k1") + .toKeyIncl("k3") + .columns("birthdate") + .filterOn("birthdate", Operator.EQ, 1L) + .filterOn("notbirthdate", Operator.EQ, 0L) + .build(); + + assertEquals("notbirthdate_key_index", rc.indexMetadata().name); + } + private void assertIndexedNone(ColumnFamilyStore cfs, ByteBuffer col, Object val) { assertIndexedCount(cfs, col, val, 0); diff --git a/test/unit/org/apache/cassandra/db/SinglePartitionReadCommandCQLTest.java 
b/test/unit/org/apache/cassandra/db/SinglePartitionReadCommandCQLTest.java index 1c891ec2b2c6..2bbe43b8c799 100644 --- a/test/unit/org/apache/cassandra/db/SinglePartitionReadCommandCQLTest.java +++ b/test/unit/org/apache/cassandra/db/SinglePartitionReadCommandCQLTest.java @@ -31,10 +31,10 @@ public void partitionLevelDeletionTest() throws Throwable { createTable("CREATE TABLE %s (bucket_id TEXT,name TEXT,data TEXT,PRIMARY KEY (bucket_id, name))"); execute("insert into %s (bucket_id, name, data) values ('8772618c9009cf8f5a5e0c18', 'test', 'hello')"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); execute("insert into %s (bucket_id, name, data) values ('8772618c9009cf8f5a5e0c19', 'test2', 'hello');"); execute("delete from %s where bucket_id = '8772618c9009cf8f5a5e0c18'"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); UntypedResultSet res = execute("select * from %s where bucket_id = '8772618c9009cf8f5a5e0c18' and name = 'test'"); assertTrue(res.isEmpty()); } diff --git a/test/unit/org/apache/cassandra/db/SinglePartitionSliceCommandTest.java b/test/unit/org/apache/cassandra/db/SinglePartitionSliceCommandTest.java index 67fd314a066e..8e50ad584534 100644 --- a/test/unit/org/apache/cassandra/db/SinglePartitionSliceCommandTest.java +++ b/test/unit/org/apache/cassandra/db/SinglePartitionSliceCommandTest.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.concurrent.TimeUnit; @@ -48,7 +47,6 @@ import org.apache.cassandra.Util; import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.QueryProcessor; -import org.apache.cassandra.cql3.UntypedResultSet; import org.apache.cassandra.cql3.statements.SelectStatement; import 
org.apache.cassandra.db.filter.AbstractClusteringIndexFilter; import org.apache.cassandra.db.filter.ClusteringIndexNamesFilter; @@ -182,7 +180,7 @@ private void testMultiNamesOrSlicesCommand(boolean flush, boolean isSlice) ck1)); if (flush) - Keyspace.open(KEYSPACE).getColumnFamilyStore(TABLE_SCLICES).forceBlockingFlush(); + Keyspace.open(KEYSPACE).getColumnFamilyStore(TABLE_SCLICES).forceBlockingFlushToSSTable(); AbstractClusteringIndexFilter clusteringFilter = createClusteringFilter(uniqueCk1, uniqueCk2, isSlice); ReadCommand cmd = SinglePartitionReadCommand.create(CFM_SLICES, @@ -301,7 +299,7 @@ public void staticColumnsAreReturned() throws IOException } // check (de)serialized iterator for sstable static cell - Schema.instance.getColumnFamilyStoreInstance(metadata.id).forceBlockingFlush(); + Schema.instance.getColumnFamilyStoreInstance(metadata.id).forceBlockingFlushToSSTable(); try (ReadExecutionController executionController = cmd.executionController(); UnfilteredPartitionIterator pi = cmd.executeLocally(executionController)) { response = ReadResponse.createDataResponse(pi, cmd); @@ -391,7 +389,7 @@ public void sstableFiltering() QueryProcessor.executeOnceInternal("INSERT INTO ks.legacy_mc_inaccurate_min_max (k, c1, c2, c3, v) VALUES (100, 2, 2, 2, 2)"); QueryProcessor.executeOnceInternal("DELETE FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=1"); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=1 AND c2=1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=1 AND c2=1"); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=1 AND c2=1 AND c3=1"); // clustering names @@ -407,7 +405,7 @@ public void sstableFiltering() new Mutation(builder.build()).apply(); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=3 AND c2=2"); - 
cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=3 AND c2=2"); assertQueryReturnsSingleRT("SELECT * FROM ks.legacy_mc_inaccurate_min_max WHERE k=100 AND c1=3 AND c2=2 AND c3=2"); // clustering names diff --git a/test/unit/org/apache/cassandra/db/TimeSortTest.java b/test/unit/org/apache/cassandra/db/TimeSortTest.java index 8ae05ea9578f..c16916f51e87 100644 --- a/test/unit/org/apache/cassandra/db/TimeSortTest.java +++ b/test/unit/org/apache/cassandra/db/TimeSortTest.java @@ -36,7 +36,7 @@ public void testMixedSources() throws Throwable ColumnFamilyStore cfs = Keyspace.open(KEYSPACE).getColumnFamilyStore(tableName); execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?) USING TIMESTAMP ?", 0, 100, 0, 100L); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?) USING TIMESTAMP ?", 0, 0, 1, 0L); assertRows(execute("SELECT * FROM %s WHERE a = ? AND b >= ? LIMIT 1000", 0, 10), row(0, 100, 0)); @@ -53,7 +53,7 @@ public void testTimeSort() throws Throwable execute("INSERT INTO %s (a, b, c) VALUES (?, ?, ?) 
USING TIMESTAMP ?", i, j * 2, 0, (long)j * 2); validateTimeSort(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); validateTimeSort(); // interleave some new data to test memtable + sstable diff --git a/test/unit/org/apache/cassandra/db/VerifyTest.java b/test/unit/org/apache/cassandra/db/VerifyTest.java index df2acb4fb179..6969567678dc 100644 --- a/test/unit/org/apache/cassandra/db/VerifyTest.java +++ b/test/unit/org/apache/cassandra/db/VerifyTest.java @@ -24,7 +24,6 @@ import org.apache.cassandra.Util; import org.apache.cassandra.cache.ChunkCache; import org.apache.cassandra.UpdateBuilder; -import org.apache.cassandra.db.compaction.AbstractCompactionStrategy; import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.db.compaction.Verifier; import org.apache.cassandra.db.marshal.UUIDType; @@ -35,7 +34,6 @@ import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.exceptions.WriteTimeoutException; import org.apache.cassandra.io.FSWriteError; -import org.apache.cassandra.io.compress.CorruptBlockException; import org.apache.cassandra.io.sstable.Component; import org.apache.cassandra.io.sstable.CorruptSSTableException; import org.apache.cassandra.io.sstable.format.SSTableReader; @@ -52,7 +50,6 @@ import org.junit.runner.RunWith; import java.io.*; -import java.net.UnknownHostException; import java.nio.file.Files; import java.util.ArrayList; import java.util.Collections; @@ -702,7 +699,7 @@ protected void fillCF(ColumnFamilyStore cfs, int partitionsPerSSTable) .apply(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } protected void fillCounterCF(ColumnFamilyStore cfs, int partitionsPerSSTable) throws WriteTimeoutException @@ -714,7 +711,7 @@ protected void fillCounterCF(ColumnFamilyStore cfs, int partitionsPerSSTable) th .apply(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } protected long simpleFullChecksum(String filename) throws IOException 
diff --git a/test/unit/org/apache/cassandra/db/columniterator/SSTableReverseIteratorTest.java b/test/unit/org/apache/cassandra/db/columniterator/SSTableReverseIteratorTest.java index 9040f1197cfa..0a423f48b790 100644 --- a/test/unit/org/apache/cassandra/db/columniterator/SSTableReverseIteratorTest.java +++ b/test/unit/org/apache/cassandra/db/columniterator/SSTableReverseIteratorTest.java @@ -81,7 +81,7 @@ public void emptyBlockTolerance() QueryProcessor.executeInternal(String.format("UPDATE %s.%s SET v1=? WHERE k=? AND c=?", KEYSPACE, table), bytes(0x20000), key, 2); QueryProcessor.executeInternal(String.format("UPDATE %s.%s SET v1=? WHERE k=? AND c=?", KEYSPACE, table), bytes(0x20000), key, 3); - tbl.forceBlockingFlush(); + tbl.forceBlockingFlushToSSTable(); SSTableReader sstable = Iterables.getOnlyElement(tbl.getLiveSSTables()); DecoratedKey dk = tbl.getPartitioner().decorateKey(Int32Type.instance.decompose(key)); RowIndexEntry indexEntry = sstable.getPosition(dk, SSTableReader.Operator.EQ); diff --git a/test/unit/org/apache/cassandra/db/commitlog/AbstractCommitLogServiceTest.java b/test/unit/org/apache/cassandra/db/commitlog/AbstractCommitLogServiceTest.java index bc5cb29e25df..741b1454b5c9 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/AbstractCommitLogServiceTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/AbstractCommitLogServiceTest.java @@ -28,7 +28,6 @@ import org.apache.cassandra.config.Config; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.commitlog.AbstractCommitLogService.SyncRunnable; -import org.apache.cassandra.utils.Clock; import org.apache.cassandra.utils.FreeRunningClock; import static org.apache.cassandra.db.commitlog.AbstractCommitLogService.DEFAULT_MARKER_INTERVAL_MILLIS; diff --git a/test/unit/org/apache/cassandra/db/commitlog/CDCTestReplayer.java b/test/unit/org/apache/cassandra/db/commitlog/CDCTestReplayer.java deleted file mode 100644 index 18bc6e0a957f..000000000000 --- 
a/test/unit/org/apache/cassandra/db/commitlog/CDCTestReplayer.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.db.commitlog; - -import java.io.File; -import java.io.IOException; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.db.rows.SerializationHelper; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.RebufferingInputStream; - -/** - * Utility class that flags the replayer as having seen a CDC mutation and calculates offset but doesn't apply mutations - */ -public class CDCTestReplayer extends CommitLogReplayer -{ - private static final Logger logger = LoggerFactory.getLogger(CDCTestReplayer.class); - - public CDCTestReplayer() throws IOException - { - super(CommitLog.instance, CommitLogPosition.NONE, null, ReplayFilter.create()); - CommitLog.instance.sync(true); - commitLogReader = new CommitLogTestReader(); - } - - public void examineCommitLog() throws IOException - { - replayFiles(new File(DatabaseDescriptor.getCommitLogLocation()).listFiles()); - } - - private 
class CommitLogTestReader extends CommitLogReader - { - @Override - protected void readMutation(CommitLogReadHandler handler, - byte[] inputBuffer, - int size, - CommitLogPosition minPosition, - final int entryLocation, - final CommitLogDescriptor desc) throws IOException - { - RebufferingInputStream bufIn = new DataInputBuffer(inputBuffer, 0, size); - Mutation mutation; - try - { - mutation = Mutation.serializer.deserialize(bufIn, desc.getMessagingVersion(), SerializationHelper.Flag.LOCAL); - if (mutation.trackedByCDC()) - sawCDCMutation = true; - } - catch (IOException e) - { - // Test fails. - throw new AssertionError(e); - } - } - } -} diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogCQLTest.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogCQLTest.java index 72356003719b..29f7a02842c6 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/CommitLogCQLTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogCQLTest.java @@ -31,11 +31,11 @@ public void testTruncateSegmentDiscard() throws Throwable execute("INSERT INTO %s (idx, data) VALUES (?, ?)", 15, Integer.toString(17)); - Collection active = new ArrayList<>(CommitLog.instance.segmentManager.getActiveSegments()); + Collection active = new ArrayList<>(CommitLog.instance.segmentManager.getSegmentsForUnflushedTables()); CommitLog.instance.forceRecycleAllSegments(); // If one of the previous segments remains, it wasn't clean. 
- active.retainAll(CommitLog.instance.segmentManager.getActiveSegments()); + active.retainAll(CommitLog.instance.segmentManager.getSegmentsForUnflushedTables()); assert active.isEmpty(); } } diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogReaderTest.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogReaderTest.java index ca76e45ac8ad..8d5430f60d3d 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/CommitLogReaderTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogReaderTest.java @@ -261,7 +261,7 @@ CommitLogPosition populateData(int entryCount) throws Throwable for (int i = midpoint; i < entryCount; i++) execute("INSERT INTO %s (idx, data) VALUES (?, ?)", i, Integer.toString(i)); - Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlush(); + Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlushToSSTable(); return result; } } diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDCTest.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDCTest.java similarity index 85% rename from test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDCTest.java rename to test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDCTest.java index 8c0647c83917..19ddb8095363 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentManagerCDCTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentAllocatorCDCTest.java @@ -18,11 +18,16 @@ package org.apache.cassandra.db.commitlog; -import java.io.*; -import java.nio.ByteBuffer; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; import java.nio.file.Files; import java.nio.file.Path; -import java.util.*; +import java.util.ArrayList; +import java.util.List; import org.junit.Assert; import 
org.junit.Before; @@ -38,10 +43,8 @@ import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.schema.TableMetadata; -public class CommitLogSegmentManagerCDCTest extends CQLTester +public class CommitLogSegmentAllocatorCDCTest extends CQLTester { - private static final Random random = new Random(); - @BeforeClass public static void setUpClass() { @@ -56,14 +59,14 @@ public void beforeTest() throws Throwable // Need to clean out any files from previous test runs. Prevents flaky test failures. CommitLog.instance.stopUnsafe(true); CommitLog.instance.start(); - ((CommitLogSegmentManagerCDC)CommitLog.instance.segmentManager).updateCDCTotalSize(); + CommitLogTestUtils.updateCDCTotalSize(CommitLog.instance.segmentManager); } @Test public void testCDCWriteFailure() throws Throwable { createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=true;"); - CommitLogSegmentManagerCDC cdcMgr = (CommitLogSegmentManagerCDC)CommitLog.instance.segmentManager; + CommitLogSegmentManager cdcMgr = CommitLog.instance.segmentManager; TableMetadata cfm = currentTableMetadata(); // Confirm that logic to check for whether or not we can allocate new CDC segments works @@ -78,7 +81,7 @@ public void testCDCWriteFailure() throws Throwable for (int i = 0; i < 100; i++) { new RowUpdateBuilder(cfm, 0, i) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); } Assert.fail("Expected CDCWriteException from full CDC but did not receive it."); @@ -87,25 +90,25 @@ public void testCDCWriteFailure() throws Throwable { // expected, do nothing } - expectCurrentCDCState(CDCState.FORBIDDEN); + CommitLogTestUtils.expectCurrentCDCState(CDCState.FORBIDDEN); // Confirm we can create a non-cdc table and write to it even while at cdc capacity createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=false;"); 
execute("INSERT INTO %s (idx, data) VALUES (1, '1');"); // Confirm that, on flush+recyle, we see files show up in cdc_raw - Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlush(); + Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()).forceBlockingFlushToSSTable(); CommitLog.instance.forceRecycleAllSegments(); cdcMgr.awaitManagementTasksCompletion(); - Assert.assertTrue("Expected files to be moved to overflow.", getCDCRawCount() > 0); + Assert.assertTrue("Expected files to be moved to overflow.", CommitLogTestUtils.getCDCRawCount() > 0); // Simulate a CDC consumer reading files then deleting them for (File f : new File(DatabaseDescriptor.getCDCLogLocation()).listFiles()) FileUtils.deleteWithConfirm(f); - // Update size tracker to reflect deleted files. Should flip flag on current allocatingFrom to allow. - cdcMgr.updateCDCTotalSize(); - expectCurrentCDCState(CDCState.PERMITTED); + // Update size tracker to reflect deleted files. Should flip flag on current active segment to allow. 
+ CommitLogTestUtils.updateCDCTotalSize(cdcMgr); + CommitLogTestUtils.expectCurrentCDCState(CDCState.PERMITTED); } finally { @@ -116,7 +119,7 @@ public void testCDCWriteFailure() throws Throwable @Test public void testSegmentFlaggingOnCreation() throws Throwable { - CommitLogSegmentManagerCDC cdcMgr = (CommitLogSegmentManagerCDC)CommitLog.instance.segmentManager; + CommitLogSegmentManager cdcMgr = CommitLog.instance.segmentManager; String ct = createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=true;"); int origSize = DatabaseDescriptor.getCDCSpaceInMB(); @@ -130,23 +133,23 @@ public void testSegmentFlaggingOnCreation() throws Throwable for (int i = 0; i < 1000; i++) { new RowUpdateBuilder(ccfm, 0, i) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); } Assert.fail("Expected CDCWriteException from full CDC but did not receive it."); } catch (CDCWriteException e) { } - expectCurrentCDCState(CDCState.FORBIDDEN); + CommitLogTestUtils.expectCurrentCDCState(CDCState.FORBIDDEN); CommitLog.instance.forceRecycleAllSegments(); cdcMgr.awaitManagementTasksCompletion(); // Delete all files in cdc_raw for (File f : new File(DatabaseDescriptor.getCDCLogLocation()).listFiles()) f.delete(); - cdcMgr.updateCDCTotalSize(); + CommitLogTestUtils.updateCDCTotalSize(cdcMgr); // Confirm cdc update process changes flag on active segment - expectCurrentCDCState(CDCState.PERMITTED); + CommitLogTestUtils.expectCurrentCDCState(CDCState.PERMITTED); // Clear out archived CDC files for (File f : new File(DatabaseDescriptor.getCDCLogLocation()).listFiles()) { @@ -164,11 +167,11 @@ public void testCDCIndexFileWriteOnSync() throws IOException { createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=true;"); new RowUpdateBuilder(currentTableMetadata(), 0, 1) - .add("data", 
randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); CommitLog.instance.sync(true); - CommitLogSegment currentSegment = CommitLog.instance.segmentManager.allocatingFrom(); + CommitLogSegment currentSegment = CommitLog.instance.segmentManager.getActiveSegment(); int syncOffset = currentSegment.lastSyncedOffset; // Confirm index file is written @@ -187,7 +190,7 @@ public void testCDCIndexFileWriteOnSync() throws IOException public void testCompletedFlag() throws IOException { createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=true;"); - CommitLogSegment initialSegment = CommitLog.instance.segmentManager.allocatingFrom(); + CommitLogSegment initialSegment = CommitLog.instance.segmentManager.getActiveSegment(); Integer originalCDCSize = DatabaseDescriptor.getCDCSpaceInMB(); DatabaseDescriptor.setCDCSpaceInMB(8); @@ -196,7 +199,7 @@ public void testCompletedFlag() throws IOException for (int i = 0; i < 1000; i++) { new RowUpdateBuilder(currentTableMetadata(), 0, 1) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); } } @@ -228,9 +231,9 @@ public void testDeleteLinkOnDiscardNoCDC() throws Throwable { createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=false;"); new RowUpdateBuilder(currentTableMetadata(), 0, 1) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); - CommitLogSegment currentSegment = CommitLog.instance.segmentManager.allocatingFrom(); + CommitLogSegment currentSegment = CommitLog.instance.segmentManager.getActiveSegment(); // Confirm that, with no CDC data present, we've hard-linked 
but have no index file Path linked = new File(DatabaseDescriptor.getCDCLogLocation(), currentSegment.logFile.getName()).toPath(); @@ -253,12 +256,12 @@ public void testDeleteLinkOnDiscardNoCDC() throws Throwable public void testRetainLinkOnDiscardCDC() throws Throwable { createTable("CREATE TABLE %s (idx int, data text, primary key(idx)) WITH cdc=true;"); - CommitLogSegment currentSegment = CommitLog.instance.segmentManager.allocatingFrom(); + CommitLogSegment currentSegment = CommitLog.instance.segmentManager.getActiveSegment(); File cdcIndexFile = currentSegment.getCDCIndexFile(); Assert.assertFalse("Expected no index file before flush but found: " + cdcIndexFile, cdcIndexFile.exists()); new RowUpdateBuilder(currentTableMetadata(), 0, 1) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); Path linked = new File(DatabaseDescriptor.getCDCLogLocation(), currentSegment.logFile.getName()).toPath(); @@ -290,7 +293,7 @@ public void testReplayLogic() throws IOException for (int i = 0; i < 1000; i++) { new RowUpdateBuilder(ccfm, 0, i) - .add("data", randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) + .add("data", CommitLogTestUtils.randomizeBuffer(DatabaseDescriptor.getCommitLogSegmentSize() / 3)) .build().apply(); } Assert.fail("Expected CDCWriteException from full CDC but did not receive it."); @@ -325,8 +328,8 @@ public void testReplayLogic() throws IOException CommitLog.instance.start(); CommitLog.instance.segmentManager.awaitManagementTasksCompletion(); } - CDCTestReplayer replayer = new CDCTestReplayer(); - replayer.examineCommitLog(); + CommitLogTestUtils.MutationCountingReplayer replayer = new CommitLogTestUtils.MutationCountingReplayer(); + replayer.replayExistingCommitLog(); // Rough sanity check -> should be files there now. 
Assert.assertTrue("Expected non-zero number of files in CDC folder after restart.", @@ -422,28 +425,4 @@ public boolean equals(Object other) return fileName.equals(cid.fileName) && offset == cid.offset; } } - - private ByteBuffer randomizeBuffer(int size) - { - byte[] toWrap = new byte[size]; - random.nextBytes(toWrap); - return ByteBuffer.wrap(toWrap); - } - - private int getCDCRawCount() - { - return new File(DatabaseDescriptor.getCDCLogLocation()).listFiles().length; - } - - private void expectCurrentCDCState(CDCState expectedState) - { - CDCState currentState = CommitLog.instance.segmentManager.allocatingFrom().getCDCState(); - if (currentState != expectedState) - { - logger.error("expectCurrentCDCState violation! Expected state: {}. Found state: {}. Current CDC allocation: {}", - expectedState, currentState, ((CommitLogSegmentManagerCDC)CommitLog.instance.segmentManager).updateCDCTotalSize()); - Assert.fail(String.format("Received unexpected CDCState on current allocatingFrom segment. Expected: %s. 
Received: %s", - expectedState, currentState)); - } - } } diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentBackpressureTest.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentBackpressureTest.java index 6b167b2500ee..f4f703837054 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentBackpressureTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogSegmentBackpressureTest.java @@ -111,21 +111,21 @@ public void testCompressedCommitLogBackpressure() throws Throwable dummyThread.start(); - AbstractCommitLogSegmentManager clsm = CommitLog.instance.segmentManager; + CommitLogSegmentManager clsm = CommitLog.instance.segmentManager; - Util.spinAssertEquals(3, () -> clsm.getActiveSegments().size(), 5); + Util.spinAssertEquals(3, () -> clsm.getSegmentsForUnflushedTables().size(), 5); Thread.sleep(1000); // Should only be able to create 3 segments not 7 because it blocks waiting for truncation that never comes - Assert.assertEquals(3, clsm.getActiveSegments().size()); + Assert.assertEquals(3, clsm.getSegmentsForUnflushedTables().size()); // Discard the currently active segments so allocation can continue. // Take snapshot of the list, otherwise this will also discard newly allocated segments. - new ArrayList<>(clsm.getActiveSegments()).forEach( clsm::archiveAndDiscard ); + new ArrayList<>(clsm.getSegmentsForUnflushedTables()).forEach(clsm::archiveAndDiscard ); // The allocated count should reach the limit again. 
- Util.spinAssertEquals(3, () -> clsm.getActiveSegments().size(), 5); + Util.spinAssertEquals(3, () -> clsm.getSegmentsForUnflushedTables().size(), 5); } finally { diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogTest.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogTest.java index 25e2f306254f..004945badfa4 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/CommitLogTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogTest.java @@ -314,13 +314,13 @@ public void testDontDeleteIfDirty() throws Exception .build(); CommitLog.instance.add(m2); - assertEquals(2, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(2, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); TableId id2 = m2.getTableIds().iterator().next(); CommitLog.instance.discardCompletedSegments(id2, CommitLogPosition.NONE, CommitLog.instance.getCurrentPosition()); // Assert we still have both our segments - assertEquals(2, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(2, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); } @Test @@ -340,14 +340,14 @@ public void testDeleteIfNotDirty() throws Exception CommitLog.instance.add(rm); CommitLog.instance.add(rm); - assertEquals(1, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(1, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); // "Flush": this won't delete anything TableId id1 = rm.getTableIds().iterator().next(); CommitLog.instance.sync(true); CommitLog.instance.discardCompletedSegments(id1, CommitLogPosition.NONE, CommitLog.instance.getCurrentPosition()); - assertEquals(1, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(1, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); // Adding new mutation on another CF, large enough (including CL entry overhead) that a new segment is created Mutation rm2 = 
new RowUpdateBuilder(cfs2.metadata(), 0, "k") @@ -359,7 +359,7 @@ public void testDeleteIfNotDirty() throws Exception CommitLog.instance.add(rm2); CommitLog.instance.add(rm2); - Collection segments = CommitLog.instance.segmentManager.getActiveSegments(); + Collection segments = CommitLog.instance.segmentManager.getSegmentsForUnflushedTables(); assertEquals(String.format("Expected 3 segments but got %d (%s)", segments.size(), getDirtyCFIds(segments)), 3, @@ -371,7 +371,7 @@ public void testDeleteIfNotDirty() throws Exception TableId id2 = rm2.getTableIds().iterator().next(); CommitLog.instance.discardCompletedSegments(id2, CommitLogPosition.NONE, CommitLog.instance.getCurrentPosition()); - segments = CommitLog.instance.segmentManager.getActiveSegments(); + segments = CommitLog.instance.segmentManager.getSegmentsForUnflushedTables(); // Assert we still have both our segment assertEquals(String.format("Expected 1 segment but got %d (%s)", segments.size(), getDirtyCFIds(segments)), @@ -617,13 +617,13 @@ public void testTruncateWithoutSnapshot() throws ExecutionException, Interrupted for (int i = 0 ; i < 5 ; i++) CommitLog.instance.add(m2); - assertEquals(2, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(2, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); CommitLogPosition position = CommitLog.instance.getCurrentPosition(); for (Keyspace keyspace : Keyspace.system()) for (ColumnFamilyStore syscfs : keyspace.getColumnFamilyStores()) CommitLog.instance.discardCompletedSegments(syscfs.metadata().id, CommitLogPosition.NONE, position); CommitLog.instance.discardCompletedSegments(cfs2.metadata().id, CommitLogPosition.NONE, position); - assertEquals(1, CommitLog.instance.segmentManager.getActiveSegments().size()); + assertEquals(1, CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size()); } finally { @@ -791,7 +791,7 @@ public void testUnwriteableFlushRecovery() throws ExecutionException, Interrupte { 
try (Closeable c = Util.markDirectoriesUnwriteable(cfs)) { - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } catch (Throwable t) { @@ -801,7 +801,7 @@ public void testUnwriteableFlushRecovery() throws ExecutionException, Interrupte } } else - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } finally @@ -854,7 +854,7 @@ public void testOutOfOrderFlushRecovery(BiConsumer { try { - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } catch (Throwable t) { diff --git a/test/unit/org/apache/cassandra/db/commitlog/CommitLogTestUtils.java b/test/unit/org/apache/cassandra/db/commitlog/CommitLogTestUtils.java new file mode 100644 index 000000000000..5c886315aad5 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/commitlog/CommitLogTestUtils.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.commitlog; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.text.MessageFormat; +import java.util.Objects; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.junit.Assert; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.Mutation; + +/** + * Collection of some helper methods and classes for use in our various CommitLog Unit Tests + */ +class CommitLogTestUtils +{ + private static final Logger logger = LoggerFactory.getLogger(CommitLogTestUtils.class); + + static ByteBuffer randomizeBuffer(int size) + { + byte[] toWrap = new byte[size]; + new Random().nextBytes(toWrap); + return ByteBuffer.wrap(toWrap); + } + + static int getCDCRawCount() + { + return new File(DatabaseDescriptor.getCDCLogLocation()).listFiles().length; + } + + static void expectCurrentCDCState(CommitLogSegment.CDCState expectedState) + { + CommitLogSegment.CDCState currentState = CommitLog.instance.segmentManager.getActiveSegment().getCDCState(); + if (currentState != expectedState) + { + logger.error("expectCurrentCDCState violation! Expected state: {}. Found state: {}. Current CDC allocation: {}", + expectedState, currentState, updateCDCTotalSize(CommitLog.instance.segmentManager)); + Assert.fail(String.format("Received unexpected CDCState on current active segment. Expected: %s. Received: %s", + expectedState, currentState)); + } + } + + static long updateCDCTotalSize(CommitLogSegmentManager segmentManager) + { + return ((CommitLogSegmentAllocatorCDC)segmentManager.segmentAllocator).updateCDCTotalSize(); + } + + /** Debug method to show written files; useful when debugging specific tests. 
*/ + static void listCommitLogFiles(String message) + { + StringBuilder result = new StringBuilder(); + result.append(String.format("%s\n", message)); + result.append("List of files in CommitLog directory:\n"); + for (File f: new File(DatabaseDescriptor.getCommitLogLocation()).listFiles()) + { + result.append(String.format("\t%s\n", f.getAbsolutePath())); + } + debugLog(result.toString()); + } + + /** Used during test debug to differentiate output visually */ + static void debugLog(String input) + { + logger.debug("\n\n**************** [TEST DEBUG] *****************\n" + + input + + System.lineSeparator() + + "***************************************************\n\n" + + System.lineSeparator()); + } + + /** + * Utility class that flags the replayer as having seen a CDC mutation and calculates offset but doesn't apply mutations. + */ + static class MutationCountingReplayer extends CommitLogReplayer + { + final ConcurrentLinkedQueue seenMutations = new ConcurrentLinkedQueue<>(); + + final ConcurrentHashMap duplicateMutations = new ConcurrentHashMap<>(); + final MutationCountingHandler mutationHandler = new MutationCountingHandler(); + + MutationCountingReplayer() throws IOException + { + super(CommitLog.instance, CommitLogPosition.NONE, null, ReplayFilter.create()); + CommitLog.instance.sync(true); + } + + void replayExistingCommitLog() throws IOException + { + for (File f: new File(DatabaseDescriptor.getCommitLogLocation()).listFiles()) + { + commitLogReader.readCommitLogSegment(mutationHandler, f, true); + } + } + + void replaySingleFile(File f) throws IOException + { + commitLogReader.readCommitLogSegment(mutationHandler, f, true); + } + + boolean hasSeenMutation(MutationIdentifier id) + { + return seenMutations.contains(id); + } + + int duplicateMutationCount(MutationIdentifier id) + { + Integer result = duplicateMutations.get(id); + return result == null ? 
0 : result; + } + + boolean hasSeenDuplicateMutations() + { + return duplicateMutations.size() > 0; + } + + private class MutationCountingHandler implements CommitLogReadHandler + { + public boolean shouldSkipSegmentOnError(CommitLogReadException exception) + { + return false; + } + + public void handleUnrecoverableError(CommitLogReadException exception) + { + Assert.fail(MessageFormat.format("Got unrecoverable error during test: {0}", exception.getMessage())); + } + + public void handleMutation(Mutation m, int size, int entryLocation, CommitLogDescriptor desc) + { + MutationIdentifier id = new MutationIdentifier(m.getKeyspaceName(), m.key(), desc.id, size, entryLocation); + + if (m.trackedByCDC()) + sawCDCMutation = true; + + if (seenMutations.contains(id)) + { + Integer pv = duplicateMutations.get(id); + if (pv == null) + pv = 0; + duplicateMutations.put(id, pv + 1); + } + else + { + seenMutations.add(id); + } + } + } + + public void reset() + { + seenMutations.clear(); + duplicateMutations.clear(); + sawCDCMutation = false; + } + + /** + * Helper class that allows us to uniquely identify a mutation at least within a single instance of a running node. 
+ */ + static class MutationIdentifier + { + final long segmentId; + final int size; + final int location; + final String keyspaceName; + final DecoratedKey decoratedKey; + + MutationIdentifier(String keyspaceName, DecoratedKey key, long segmentId, int size, int location) + { + this.keyspaceName = keyspaceName; + this.decoratedKey = key; + this.segmentId = segmentId; + this.size = size; + this.location = location; + } + + @Override + public boolean equals(Object o) + { + if (o == this) + return true; + + if (!(o instanceof MutationIdentifier)) + return false; + + MutationIdentifier other = (MutationIdentifier) o; + + return other.size == this.size && + other.location == this.location && + other.segmentId == this.segmentId && + other.keyspaceName.equals(this.keyspaceName) && + other.decoratedKey.equals(this.decoratedKey); + } + + @Override + public int hashCode() + { + return Objects.hash(size, location, segmentId, keyspaceName, decoratedKey); + } + + @Override + public String toString() + { + return new StringBuilder() + .append("sId: ").append(segmentId).append(", ") + .append(" size: ").append(size).append(", ") + .append(" loc: ").append(location).append(", ") + .append(" ks: ").append(keyspaceName).append(", ") + .append(" dk: ").append(decoratedKey) + .toString(); + } + } + } + + static class NoopMutationHandler implements CommitLogReadHandler + { + public boolean shouldSkipSegmentOnError(CommitLogReadException exception) throws IOException + { + return false; + } + + public void handleUnrecoverableError(CommitLogReadException exception) throws IOException + { } + + public void handleMutation(Mutation m, int size, int entryLocation, CommitLogDescriptor desc) + { } + } +} diff --git a/test/unit/org/apache/cassandra/db/commitlog/ResumableCommitLogReaderTest.java b/test/unit/org/apache/cassandra/db/commitlog/ResumableCommitLogReaderTest.java new file mode 100644 index 000000000000..fae154ec3653 --- /dev/null +++ 
b/test/unit/org/apache/cassandra/db/commitlog/ResumableCommitLogReaderTest.java @@ -0,0 +1,461 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.commitlog; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; + +/** + * Tests various alignments, offsets, and operations of the {@link ResumableCommitLogReader} + */ +public class ResumableCommitLogReaderTest extends CQLTester +{ + private CommitLogSegment populatedSegment; + private CommitLogTestUtils.MutationCountingReplayer testReplayer; + + @Before + public void setUpTest() throws Throwable + { + CommitLog.instance.resetUnsafe(true); + + testReplayer = new CommitLogTestUtils.MutationCountingReplayer(); + populateReferenceData(true); + + // Should have well more than 3 segments to work with on subsequent tests. 
+ Assert.assertTrue(CommitLog.instance.segmentManager.getSegmentsForUnflushedTables().size() > 3); + + // And always reset which file we're using as including CDC Mutations, since things may change between tests + CommitLogTestUtils.MutationCountingReplayer testReplayer = new CommitLogTestUtils.MutationCountingReplayer(); + for (CommitLogSegment cls : CommitLog.instance.segmentManager.getSegmentsForUnflushedTables()) + { + testReplayer.replaySingleFile(cls.logFile); + if (testReplayer.sawCDCMutation) + { + populatedSegment = cls; + return; + } + } + throw new RuntimeException("No mutations seen in passed in collection."); + } + + /** + * Expect operation as though non-resumable, read file to end and complete. + */ + @Test + public void testNonResumedGeneralCase() throws Throwable + { + testReplayer.replaySingleFile(populatedSegment.logFile); + + Assert.assertTrue("Did not see any CDC enabled mutations.", testReplayer.sawCDCMutation); + Assert.assertFalse("Saw a duplicate mutation while replaying a single file. This... shouldn't happen.", + testReplayer.hasSeenDuplicateMutations()); + } + + /** + * Confirm our duplicate mutation testing infrastructure is working. + */ + @Test + public void testDuplicateCheckLogic() throws Throwable + { + testReplayer.replaySingleFile(populatedSegment.logFile); + + Assert.assertTrue("Did not see any CDC enabled mutations.", testReplayer.sawCDCMutation); + Assert.assertFalse("Saw a duplicate mutation while replaying a single file. This... shouldn't happen.", + testReplayer.hasSeenDuplicateMutations()); + + testReplayer.replaySingleFile(populatedSegment.logFile); + Assert.assertTrue("Expected to see duplicate mutations on 2nd replay of file.", testReplayer.hasSeenDuplicateMutations()); + } + + /** + * Expect operation to pick up where left off and feed total mutation # consistent w/mutations in CL, so we do a + * single normal full replay, then a 2 step replay to ensure the total mutation count is as expected. 
+ */ + @Test + public void testSingleResumeCase() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + Assert.assertTrue("Failed to successfully perform a start to finish CL read.", expectedCount != 0); + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + // This runs the risk of being flaky, since if we don't have CDC mutations in the first or second half, + // this test will fail out. This has been 100% stable at the .5 barrier both non and compressed, but keep in mind. + rr.readPartial((int)(rr.rawReader.length() * .5)); + Assert.assertFalse("Resumable Reader got constructed badly somehow.", rr.isClosed); + Assert.assertNotNull("Resumable Reader doesn't have a RAR cached in it as expected.", rr.rawReader); + + // Confirm we didn't just parse everything in the first part + Assert.assertNotEquals(expectedCount, testReplayer.seenMutations.size()); + + int interimCount = testReplayer.seenMutations.size(); + Assert.assertTrue("Failed on initial partial replay", interimCount != 0); + + Assert.assertFalse("Expected reader to be open still.", rr.isClosed()); + Assert.assertTrue("Interim replay should have played back less than a full replay. Check logs.", interimCount < expectedCount); + + rr.readPartial(CommitLogReader.READ_TO_END_OF_FILE); + Assert.assertEquals("Expected resumable read to give same # mutations as non but did not.", + testReplayer.seenMutations.size(), + expectedCount); + } + } + + /** Test multiple resumes w/end not matching SyncSegment offsets. 
*/ + @Test + public void testMultipleResumeNonAligned() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial(1024); + + // Sentinel to keep from locking the test. Since we are misusing the API by not sending the end sentinel, + // this checks to see if that kind of "bad offset overflow" gives us both a) a stable API, and b) the right + // parsed results from our end file. + int limit = 50; + + // throw some strange offsets at this and make sure it's robust to them, non-multiple of SyncSegment end + int offset = 1024 * 512; + while (limit > 0 && !rr.isClosed() && !rr.readToExhaustion) + { + rr.readPartial(offset); + offset += 1024 * 512; + --limit; + } + rr.readToCompletion(); + } + Assert.assertEquals("Expected non-aligned resumable read to give same # mutations as non but did not.", + expectedCount, + testReplayer.seenMutations.size()); + } + + /** Ensure that resumable readers w/offsets at sync segment boundaries don't blow up logic */ + @Test + public void testResumingAtAlignedOffsets() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + + ArrayList segmentBoundaries = new ArrayList<>(); + // First, we want to get a list of all the segment boundaries out of the SegmentIterator + try(ResumableCommitLogReader fr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + fr.offsetLimit = CommitLogReader.READ_TO_END_OF_FILE; + CommitLogSegmentReader lsr = new CommitLogSegmentReader(fr); + for (CommitLogSegmentReader.SyncSegment ss : lsr) + { + segmentBoundaries.add(ss.endPosition); + } + Assert.assertTrue(segmentBoundaries.size() > 0); + } + Assert.assertNotEquals(-1, 
expectedCount); + testReplayer.reset(); + + // And now we iterate through the file using those boundaries, ensuring it works in our resumable reader. + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + // Confirm first read doesn't exhaust so the test is actually testing something. + rr.readPartial(segmentBoundaries.get(0)); + Assert.assertFalse(rr.readToExhaustion); + Assert.assertFalse(rr.isClosed); + + for (Integer offset : segmentBoundaries) + { + rr.readPartial(offset); + } + rr.readToCompletion(); + + Assert.assertEquals("Reading based on segment offsets produced unexpected results.", + expectedCount, + testReplayer.seenMutations.size()); + } + } + + /** Expect operation to pick up and feed total mutation # consistent w/mutations in CL, no errors. */ + @Test + public void testResumeAtFileEndOffset() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial((int) populatedSegment.logFile.length()); + Assert.assertEquals(expectedCount, testReplayer.seenMutations.size()); + } + } + + /** + * If the offset provided by the user is past the end of the file itself, we expect a graceful read when using the logic + * on a resumable reader as we read to end of file. 
+ */ + @Test + public void testOffsetPastEnd() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial(Integer.MAX_VALUE); + Assert.assertEquals(expectedCount, testReplayer.seenMutations.size()); + } + } + + /** If the offset provided by the user is < 0, we expect a no-op on read followed by the ability to resume and read */ + @Test + public void testNegativeOffset() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial(-1); + // Expect nothing to have been read so at beginning of file still + Assert.assertFalse(rr.readToExhaustion); + Assert.assertEquals(0, testReplayer.seenMutations.size()); + + // Then re-use the infra to read to end + rr.readPartial(CommitLogReader.READ_TO_END_OF_FILE); + Assert.assertEquals(expectedCount, testReplayer.seenMutations.size()); + Assert.assertEquals(0, testReplayer.duplicateMutations.size()); + } + } + + /** + * Confirm we gracefully handle cases where people may put in an offset that regresses what we're reading. Since the + * RAR and iteration should be unidirectional, we should still not see duplicates. + */ + @Test + public void testRepeatOffsets() throws Throwable + { + int expectedCount = getExpectedMutationCount(populatedSegment.logFile); + Assert.assertTrue(testReplayer.duplicateMutations.isEmpty()); + Assert.assertNotEquals(0, expectedCount); + + // Cache what we saw on first replay vs. 
newest, confirm >= all original seen + ArrayList ids = new ArrayList<>(testReplayer.seenMutations); + testReplayer.reset(); + + ArrayList newIds = null; + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(populatedSegment.logFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial(1024 * 512); + int countBeforeRegression = testReplayer.seenMutations.size(); + + // Confirm no reads if we regress our offset + rr.readPartial(1024 * 256); + Assert.assertEquals(0, testReplayer.duplicateMutations.size()); + Assert.assertEquals(countBeforeRegression, testReplayer.seenMutations.size()); + + // Confirm can resume w/correct offset + rr.readPartial(CommitLogReader.READ_TO_END_OF_FILE); + + newIds = new ArrayList<>(testReplayer.seenMutations); + + // Confirm all mutation id's seen in straight read are seen by partial. + for (Object id : ids) + if (!newIds.contains(id)) + Assert.fail(String.format("Missing id in resumable replay: %s", id)); + Assert.assertTrue(testReplayer.duplicateMutations.size() > 0); + } + } + + /** + * Expect RTE if the input file is completely missing. Since these exceptions are user-facing in CL consumption in + * a CDC context, this unit test serves to calcify that UI interaction a bit and confirm we're deliberate about changing + * the type of exception we're throwing. + */ + @Test + public void testMissingInputFile() throws Throwable + { + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(new File("This_should_fail.txt"), + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { } + catch (RuntimeException e) + { + if (!e.getMessage().contains("version")) + Assert.fail(); + return; + } + Assert.fail("Expected RuntimeException on creation. 
Did not get it."); + } + + @Test + public void testWrongFile() throws Throwable + { + File tempFile = File.createTempFile("test_file", ".tmp"); + try(FileWriter writer = new FileWriter(tempFile)) + { + writer.write("This is really not a commit log header. This will go badly."); + } + + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(tempFile, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { } + catch (RuntimeException rte) + { + if (rte.getMessage().contains("version")) + return; + } + Assert.fail("Expected RuntimeException complaining about inability to parse version from CommitLogHeader"); + } + + /** + * Uses CDC pipeline to read CDC index file + hard linked CDC File w/multiple mutations to confirm user use-case is functional. + * This test requires CDC to be enabled to run as we need the CDC Allocator to be hardlinking files, etc. + * + * This test is unique (and has surfaced multiple pain points) in that it is the only test that is processing + * a file being actively written. 
+ */ + @Test + public void testUsingCDCOffsets() throws Throwable + { + if (!(CommitLog.instance.segmentManager.segmentAllocator instanceof CommitLogSegmentAllocatorCDC)) + return; + + CommitLogSegment activeSegment = CommitLog.instance.segmentManager.getActiveSegment(); + Assert.assertSame(activeSegment.getCDCState(), CommitLogSegment.CDCState.CONTAINS); + + File cdcSegment = activeSegment.getCDCFile(); + Assert.assertTrue(cdcSegment.exists()); + + // Confirm we have a reasonable offset now and can scan to it + int cdcOffset = parseCDCOffset(activeSegment.getCDCIndexFile()); + Assert.assertNotEquals(0, cdcOffset); + + int totalReadThroughCDC = -1; + // Start a resumable reader wrapped around our active segment + try(ResumableCommitLogReader rr = new ResumableCommitLogReader(cdcSegment, + testReplayer.mutationHandler, + CommitLogPosition.NONE, + CommitLogReader.ALL_MUTATIONS, + true)) + { + rr.readPartial(cdcOffset); + int lastMutationsRead = testReplayer.seenMutations.size(); + Assert.assertNotEquals(0, lastMutationsRead); + + byte[] buffer = new byte[1024 * 128]; + + int tries = 250; + int newCDCOffset; + while (tries > 0) + { + cdcOffset = parseCDCOffset(activeSegment.getCDCIndexFile()); + writeReferenceLines(50, buffer); + CommitLog.instance.sync(true); + + newCDCOffset = parseCDCOffset(activeSegment.getCDCIndexFile()); + // Look for 1 cdcOffset change and read it. We stop here since we can't really deterministically get + // a repeatable test on # of writes leading to # of SyncSegment's written w/compression, etc. + if (cdcOffset != newCDCOffset) + { + rr.readPartial(newCDCOffset); + break; + } + tries--; + } + Assert.assertNotEquals(0, tries); + // Confirm we read far enough to exhaust the underlying buffer in compression context. This will exercise the RAR re-alloc code. 
+ if (DatabaseDescriptor.getCommitLogCompression() != null) + Assert.assertTrue(rr.readToExhaustion); + + // Write a chunk more data; hopefully cycles it so we have a full file, though not particularly relevant to our needs. + buffer = new byte[1024 * 512]; + writeReferenceLines(150, buffer); + CommitLog.instance.sync(true); + + rr.readToCompletion(); + totalReadThroughCDC = testReplayer.seenMutations.size(); + } + + // Get final count with one straight through read, confirm matches CDC staged + testReplayer.reset(); + int expectedSeen = getExpectedMutationCount(cdcSegment); + Assert.assertEquals(expectedSeen, totalReadThroughCDC); + } + + /** Straightforward single non-resumable replay to count mutations expected. Assumes you're looking for > 0 */ + private int getExpectedMutationCount(File file) throws IOException + { + testReplayer.replaySingleFile(file); + int expectedCount = testReplayer.seenMutations.size(); + testReplayer.reset(); + Assert.assertTrue("Failed to successfully perform a start to finish CL read.", expectedCount != 0); + return expectedCount; + } + + private int parseCDCOffset(File cdcIndexFile) throws IOException + { + try(BufferedReader br = new BufferedReader(new FileReader(cdcIndexFile))) + { + return Integer.parseInt(br.readLine()); + } + } +} diff --git a/test/unit/org/apache/cassandra/db/commitlog/SegmentReaderTest.java b/test/unit/org/apache/cassandra/db/commitlog/SegmentReaderTest.java index ce209351fa62..7fdbf6f4f09f 100644 --- a/test/unit/org/apache/cassandra/db/commitlog/SegmentReaderTest.java +++ b/test/unit/org/apache/cassandra/db/commitlog/SegmentReaderTest.java @@ -107,9 +107,9 @@ private void compressedSegmenter(ICompressor compressor) throws IOException fos.getChannel().write(compBuffer); fos.close(); - try (RandomAccessReader reader = RandomAccessReader.open(compressedFile)) + try (ResumableCommitLogReader rr = new ResumableCommitLogReader(compressedFile, new CommitLogTestUtils.NoopMutationHandler())) { - CompressedSegmenter 
segmenter = new CompressedSegmenter(compressor, reader); + CompressedSegmenter segmenter = new CompressedSegmenter(compressor, rr); int fileLength = (int) compressedFile.length(); SyncSegment syncSegment = segmenter.nextSegment(0, fileLength); FileDataInput fileDataInput = syncSegment.input; @@ -195,11 +195,11 @@ public void underlyingEncryptedSegmenterTest(BiFunction false); + MessagingService.instance().inboundSink.add((message) -> false); } @Before @@ -101,7 +86,7 @@ SSTableReader makeSSTable(boolean orphan) int pk = nextSSTableKey++; Set pre = cfs.getLiveSSTables(); QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v) VALUES(?, ?)", ks, tbl), pk, pk); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set post = cfs.getLiveSSTables(); Set diff = new HashSet<>(post); diff.removeAll(pre); diff --git a/test/unit/org/apache/cassandra/db/compaction/ActiveCompactionsTest.java b/test/unit/org/apache/cassandra/db/compaction/ActiveCompactionsTest.java index be5e7df7a78f..445ff573e6f3 100644 --- a/test/unit/org/apache/cassandra/db/compaction/ActiveCompactionsTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/ActiveCompactionsTest.java @@ -18,7 +18,6 @@ package org.apache.cassandra.db.compaction; -import java.util.ArrayList; import java.util.Collections; import java.util.Map; import java.util.Set; @@ -60,7 +59,7 @@ public void testSecondaryIndexTracking() throws Throwable for (int i = 0; i < 5; i++) { execute("INSERT INTO %s (pk, ck, a, b) VALUES ("+i+", 2, 3, 4)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } Index idx = getCurrentColumnFamilyStore().indexManager.getIndexByName(idxName); @@ -83,7 +82,7 @@ public void testIndexSummaryRedistributionTracking() throws Throwable for (int i = 0; i < 5; i++) { execute("INSERT INTO %s (pk, ck, a, b) VALUES ("+i+", 2, 3, 4)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + 
getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } Set sstables = getCurrentColumnFamilyStore().getLiveSSTables(); try (LifecycleTransaction txn = getCurrentColumnFamilyStore().getTracker().tryModify(sstables, OperationType.INDEX_SUMMARY)) @@ -108,7 +107,7 @@ public void testViewBuildTracking() throws Throwable for (int i = 0; i < 5; i++) { execute("INSERT INTO %s (k1, c1, val) VALUES ("+i+", 2, 3)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } execute(String.format("CREATE MATERIALIZED VIEW %s.view1 AS SELECT k1, c1, val FROM %s.%s WHERE k1 IS NOT NULL AND c1 IS NOT NULL AND val IS NOT NULL PRIMARY KEY (val, k1, c1)", keyspace(), keyspace(), currentTable())); View view = Iterables.getOnlyElement(getCurrentColumnFamilyStore().viewManager); @@ -132,7 +131,7 @@ public void testScrubOne() throws Throwable for (int i = 0; i < 5; i++) { execute("INSERT INTO %s (pk, ck, a, b) VALUES (" + i + ", 2, 3, 4)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } SSTableReader sstable = Iterables.getFirst(getCurrentColumnFamilyStore().getLiveSSTables(), null); @@ -157,7 +156,7 @@ public void testVerifyOne() throws Throwable for (int i = 0; i < 5; i++) { execute("INSERT INTO %s (pk, ck, a, b) VALUES (" + i + ", 2, 3, 4)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); } SSTableReader sstable = Iterables.getFirst(getCurrentColumnFamilyStore().getLiveSSTables(), null); diff --git a/test/unit/org/apache/cassandra/db/compaction/AntiCompactionBytemanTest.java b/test/unit/org/apache/cassandra/db/compaction/AntiCompactionBytemanTest.java index 38d2607d2f7d..499002c9dcb5 100644 --- a/test/unit/org/apache/cassandra/db/compaction/AntiCompactionBytemanTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/AntiCompactionBytemanTest.java @@ -67,7 +67,7 @@ public 
void testRedundantTransitions() throws Throwable execute("insert into %s (id, i) values (1, 1)"); execute("insert into %s (id, i) values (2, 1)"); execute("insert into %s (id, i) values (3, 1)"); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); UntypedResultSet res = execute("select token(id) as tok from %s"); Iterator it = res.iterator(); List tokens = new ArrayList<>(); diff --git a/test/unit/org/apache/cassandra/db/compaction/AntiCompactionTest.java b/test/unit/org/apache/cassandra/db/compaction/AntiCompactionTest.java index b2618e54f7ff..8f5b1608765c 100644 --- a/test/unit/org/apache/cassandra/db/compaction/AntiCompactionTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/AntiCompactionTest.java @@ -28,7 +28,6 @@ import java.util.Set; import java.util.UUID; import java.util.function.Predicate; -import java.util.stream.Collectors; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -299,7 +298,7 @@ public void generateSStable(ColumnFamilyStore store, String Suffix) .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); } @Test @@ -442,7 +441,7 @@ private ColumnFamilyStore prepareColumnFamilyStore() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); return store; } diff --git a/test/unit/org/apache/cassandra/db/compaction/BlacklistingCompactionsTest.java b/test/unit/org/apache/cassandra/db/compaction/BlacklistingCompactionsTest.java index e0f24f28d68b..2d3c886a31cc 100644 --- a/test/unit/org/apache/cassandra/db/compaction/BlacklistingCompactionsTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/BlacklistingCompactionsTest.java @@ -158,7 +158,7 @@ private void testBlacklisting(String tableName) throws Exception maxTimestampExpected = Math.max(timestamp, maxTimestampExpected); inserted.add(key); } - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); CompactionsTest.assertMaxTimestamp(cfs, maxTimestampExpected); assertEquals(inserted.toString(), inserted.size(), Util.getAll(Util.cmd(cfs).build()).size()); } diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionAwareWriterTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionAwareWriterTest.java index 68936f55427b..766eb4e8850a 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionAwareWriterTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionAwareWriterTest.java @@ -195,7 +195,7 @@ private void populate(int count) throws Throwable execute(String.format("INSERT INTO %s.%s(k, t, v) VALUES (?, ?, ?)", KEYSPACE, TABLE), i, j, b); ColumnFamilyStore cfs = getColumnFamilyStore(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); if (cfs.getLiveSSTables().size() > 1) { // we want just one big sstable to avoid doing actual compaction in compact() above diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionControllerTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionControllerTest.java index 0ab714acfc34..0d39a625d91e 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionControllerTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionControllerTest.java @@ -21,7 +21,6 @@ import java.nio.ByteBuffer; import java.util.Set; import java.util.function.LongPredicate; -import java.util.function.Predicate; import com.google.common.collect.Sets; import org.junit.BeforeClass; @@ -95,7 +94,7 @@ public void testMaxPurgeableTimestamp() { assertPurgeBoundary(controller.getPurgeEvaluator(key), timestamp1); //memtable only - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertTrue(controller.getPurgeEvaluator(key).test(Long.MAX_VALUE)); //no memtables and no sstables } @@ -103,7 +102,7 @@ public void testMaxPurgeableTimestamp() // create another sstable applyMutation(cfs.metadata(), key, timestamp2); - 
cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // check max purgeable timestamp when compacting the first sstable with and without a memtable try (CompactionController controller = new CompactionController(cfs, compacting, 0)) @@ -116,7 +115,7 @@ public void testMaxPurgeableTimestamp() } // check max purgeable timestamp again without any sstables but with different insertion orders on the memtable - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); //newest to oldest try (CompactionController controller = new CompactionController(cfs, null, 0)) @@ -128,7 +127,7 @@ public void testMaxPurgeableTimestamp() assertPurgeBoundary(controller.getPurgeEvaluator(key), timestamp3); //memtable only } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); //oldest to newest try (CompactionController controller = new CompactionController(cfs, null, 0)) @@ -156,14 +155,14 @@ public void testGetFullyExpiredSSTables() // create sstable with tombstone that should be expired in no older timestamps applyDeleteMutation(cfs.metadata(), key, timestamp2); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // first sstable with tombstone is compacting Set compacting = Sets.newHashSet(cfs.getLiveSSTables()); // create another sstable with more recent timestamp applyMutation(cfs.metadata(), key, timestamp1); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // second sstable is overlapping Set overlapping = Sets.difference(Sets.newHashSet(cfs.getLiveSSTables()), compacting); diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionStrategyManagerTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionStrategyManagerTest.java index 73e6852eb178..ddabcaf82e22 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionStrategyManagerTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionStrategyManagerTest.java @@ -511,7 +511,7 @@ private static SSTableReader 
createSSTableWithKey(String keyspace, String table, .build() .applyUnsafe(); Set before = cfs.getLiveSSTables(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set after = cfs.getLiveSSTables(); return Iterables.getOnlyElement(Sets.difference(after, before)); } diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionTaskTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionTaskTest.java index af74603fd205..a9c178276e18 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionTaskTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionTaskTest.java @@ -71,10 +71,10 @@ public void compactionInterruption() throws Exception cfs.getCompactionStrategyManager().disable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (1, 1);"); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (2, 2);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (3, 3);"); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (4, 4);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set sstables = cfs.getLiveSSTables(); Assert.assertEquals(2, sstables.size()); @@ -111,13 +111,13 @@ public void mixedSSTableFailure() throws Exception { cfs.getCompactionStrategyManager().disable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (1, 1);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (2, 2);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (3, 3);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); QueryProcessor.executeInternal("INSERT INTO ks.tbl (k, v) VALUES (4, 4);"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List sstables = new ArrayList<>(cfs.getLiveSSTables()); 
Assert.assertEquals(4, sstables.size()); diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionsBytemanTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionsBytemanTest.java index 2519389b7bfa..75268082accb 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionsBytemanTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionsBytemanTest.java @@ -19,18 +19,28 @@ package org.apache.cassandra.db.compaction; import java.util.concurrent.TimeUnit; +import java.util.Collection; +import java.util.Collections; +import java.util.function.Consumer; +import java.util.stream.Collectors; import org.junit.Test; import org.junit.runner.RunWith; import org.apache.cassandra.cql3.CQLTester; import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.dht.Range; +import org.apache.cassandra.dht.Token; +import org.apache.cassandra.io.sstable.Descriptor; import org.apache.cassandra.utils.FBUtilities; import org.jboss.byteman.contrib.bmunit.BMRule; import org.jboss.byteman.contrib.bmunit.BMRules; import org.jboss.byteman.contrib.bmunit.BMUnitRunner; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; @RunWith(BMUnitRunner.class) public class CompactionsBytemanTest extends CQLTester @@ -118,7 +128,7 @@ public void testCompactingCFCounting() throws Throwable execute("INSERT INTO %s (k, c, v) VALUES (?, ?, ?)", 0, 1, 1); assertEquals(0, CompactionManager.instance.compactingCF.count(cfs)); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); FBUtilities.waitOnFutures(CompactionManager.instance.submitBackground(cfs)); assertEquals(0, CompactionManager.instance.compactingCF.count(cfs)); @@ -135,10 +145,82 @@ private void createPossiblyExpiredSSTable(final ColumnFamilyStore cfs, final boo { execute("INSERT INTO %s (id, val) values (2, 'immortal')"); } - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); } private void createLowGCGraceTable(){ createTable("CREATE TABLE %s (id int PRIMARY KEY, val text) with compaction = {'class':'SizeTieredCompactionStrategy', 'enabled': 'false'} AND gc_grace_seconds=0"); } -} \ No newline at end of file + + @Test + @BMRule(name = "Stop all compactions", + targetClass = "CompactionTask", + targetMethod = "runMayThrow", + targetLocation = "AT INVOKE getCompactionAwareWriter", + action = "$ci.stop()") + public void testStopUserDefinedCompactionRepaired() throws Throwable + { + testStopCompactionRepaired((cfs) -> { + Collection files = cfs.getLiveSSTables().stream().map(s -> s.descriptor).collect(Collectors.toList()); + FBUtilities.waitOnFuture(CompactionManager.instance.submitUserDefined(cfs, files, CompactionManager.NO_GC)); + }); + } + + @Test + @BMRule(name = "Stop all compactions", + targetClass = "CompactionTask", + targetMethod = "runMayThrow", + targetLocation = "AT INVOKE getCompactionAwareWriter", + action = "$ci.stop()") + public void testStopSubRangeCompactionRepaired() throws Throwable + { + testStopCompactionRepaired((cfs) -> { + Collection> ranges = Collections.singleton(new Range<>(cfs.getPartitioner().getMinimumToken(), + cfs.getPartitioner().getMaximumToken())); + CompactionManager.instance.forceCompactionForTokenRange(cfs, ranges); + }); + } + + public void testStopCompactionRepaired(Consumer compactionRunner) throws Throwable + { + String table = createTable("CREATE TABLE %s (k INT, c INT, v INT, PRIMARY KEY (k, c))"); + ColumnFamilyStore cfs = Keyspace.open(CQLTester.KEYSPACE).getColumnFamilyStore(table); + cfs.disableAutoCompaction(); + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 10; j++) + { + execute("insert into %s (k, c, v) values (?, ?, ?)", i, j, i*j); + } + cfs.forceBlockingFlushToSSTable(); + } + cfs.getCompactionStrategyManager().mutateRepaired(cfs.getLiveSSTables(), System.currentTimeMillis(), null, false); + for (int i = 0; i < 5; i++) + { + for (int j = 0; 
j < 10; j++) + { + execute("insert into %s (k, c, v) values (?, ?, ?)", i, j, i*j); + } + cfs.forceBlockingFlushToSSTable(); + } + + assertTrue(cfs.getTracker().getCompacting().isEmpty()); + assertTrue(CompactionManager.instance.active.getCompactions().stream().noneMatch(h -> h.getCompactionInfo().getTableMetadata().equals(cfs.metadata))); + + try + { + compactionRunner.accept(cfs); + fail("compaction should fail"); + } + catch (RuntimeException t) + { + if (!(t.getCause().getCause() instanceof CompactionInterruptedException)) + throw t; + //expected + } + + assertTrue(cfs.getTracker().getCompacting().isEmpty()); + assertTrue(CompactionManager.instance.active.getCompactions().stream().noneMatch(h -> h.getCompactionInfo().getTableMetadata().equals(cfs.metadata))); + + } +} diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionsCQLTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionsCQLTest.java index b003721c06ec..3cdb3a0aff4e 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionsCQLTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionsCQLTest.java @@ -290,7 +290,7 @@ public void testCompactionInvalidRTs() throws Throwable RangeTombstone rt = new RangeTombstone(Slice.ALL, new DeletionTime(System.currentTimeMillis(), -1)); RowUpdateBuilder rub = new RowUpdateBuilder(getCurrentColumnFamilyStore().metadata(), System.currentTimeMillis() * 1000, 22).clustering(33).addRangeTombstone(rt); rub.build().apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); compactAndValidate(); readAndValidate(true); readAndValidate(false); @@ -304,7 +304,7 @@ public void testCompactionInvalidTombstone() throws Throwable // write a standard tombstone with negative local deletion time (LDTs are not set by user and should not be negative): RowUpdateBuilder rub = new RowUpdateBuilder(getCurrentColumnFamilyStore().metadata(), -1, System.currentTimeMillis() * 1000, 
22).clustering(33).delete("b"); rub.build().apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); compactAndValidate(); readAndValidate(true); readAndValidate(false); @@ -318,7 +318,7 @@ public void testCompactionInvalidPartitionDeletion() throws Throwable // write a partition deletion with negative local deletion time (LDTs are not set by user and should not be negative):: PartitionUpdate pu = PartitionUpdate.simpleBuilder(getCurrentColumnFamilyStore().metadata(), 22).nowInSec(-1).delete().build(); new Mutation(pu).apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); compactAndValidate(); readAndValidate(true); readAndValidate(false); @@ -331,7 +331,7 @@ public void testCompactionInvalidRowDeletion() throws Throwable prepare(); // write a row deletion with negative local deletion time (LDTs are not set by user and should not be negative): RowUpdateBuilder.deleteRowAt(getCurrentColumnFamilyStore().metadata(), System.currentTimeMillis() * 1000, -1, 22, 33).apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); compactAndValidate(); readAndValidate(true); readAndValidate(false); @@ -353,7 +353,7 @@ public void testIndexedReaderRowDeletion() throws Throwable DatabaseDescriptor.setColumnIndexSize(1024); prepareWide(); RowUpdateBuilder.deleteRowAt(getCurrentColumnFamilyStore().metadata(), System.currentTimeMillis() * 1000, -1, 22, 33).apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); readAndValidate(true); readAndValidate(false); DatabaseDescriptor.setColumnIndexSize(maxSizePre); @@ -369,7 +369,7 @@ public void testIndexedReaderTombstone() throws Throwable prepareWide(); RowUpdateBuilder rub = new RowUpdateBuilder(getCurrentColumnFamilyStore().metadata(), -1, 
System.currentTimeMillis() * 1000, 22).clustering(33).delete("b"); rub.build().apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); readAndValidate(true); readAndValidate(false); DatabaseDescriptor.setColumnIndexSize(maxSizePre); @@ -386,7 +386,7 @@ public void testIndexedReaderRT() throws Throwable RangeTombstone rt = new RangeTombstone(Slice.ALL, new DeletionTime(System.currentTimeMillis(), -1)); RowUpdateBuilder rub = new RowUpdateBuilder(getCurrentColumnFamilyStore().metadata(), System.currentTimeMillis() * 1000, 22).clustering(33).addRangeTombstone(rt); rub.build().apply(); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); readAndValidate(true); readAndValidate(false); DatabaseDescriptor.setColumnIndexSize(maxSizePre); @@ -408,7 +408,7 @@ public void testLCSThresholdParams() throws Throwable { execute("insert into %s (id, id2, t) values (?, ?, ?)", i, j, value); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } assertEquals(50, cfs.getLiveSSTables().size()); LeveledCompactionStrategy lcs = (LeveledCompactionStrategy) cfs.getCompactionStrategyManager().getUnrepairedUnsafe().first(); @@ -425,7 +425,7 @@ public void testSTCSinL0() throws Throwable ColumnFamilyStore cfs = getCurrentColumnFamilyStore(); cfs.disableAutoCompaction(); execute("insert into %s (id, id2, t) values (?, ?, ?)", 1,1,"L1"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.forceMajorCompaction(); SSTableReader l1sstable = cfs.getLiveSSTables().iterator().next(); assertEquals(1, l1sstable.getSSTableLevel()); @@ -439,7 +439,7 @@ public void testSTCSinL0() throws Throwable { execute("insert into %s (id, id2, t) values (?, ?, ?)", i, j, value); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } assertEquals(51, cfs.getLiveSSTables().size()); @@ -533,7 +533,7 @@ public void 
testPerCFSNeverPurgeTombstonesHelper(boolean deletedCell) throws Thr { execute("INSERT INTO %s (id, b) VALUES (?, ?)", i, String.valueOf(i)); } - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); assertTombstones(getCurrentColumnFamilyStore().getLiveSSTables().iterator().next(), false); if (deletedCell) @@ -541,7 +541,7 @@ public void testPerCFSNeverPurgeTombstonesHelper(boolean deletedCell) throws Thr else execute("DELETE FROM %s WHERE id = ?", 50); getCurrentColumnFamilyStore().setNeverPurgeTombstones(false); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Thread.sleep(2000); // wait for gcgs to pass getCurrentColumnFamilyStore().forceMajorCompaction(); assertTombstones(getCurrentColumnFamilyStore().getLiveSSTables().iterator().next(), false); @@ -550,7 +550,7 @@ public void testPerCFSNeverPurgeTombstonesHelper(boolean deletedCell) throws Thr else execute("DELETE FROM %s WHERE id = ?", 44); getCurrentColumnFamilyStore().setNeverPurgeTombstones(true); - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); Thread.sleep(1100); getCurrentColumnFamilyStore().forceMajorCompaction(); assertTombstones(getCurrentColumnFamilyStore().getLiveSSTables().iterator().next(), true); diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionsPurgeTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionsPurgeTest.java index dcd5270f9657..203c94de8777 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionsPurgeTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionsPurgeTest.java @@ -19,9 +19,9 @@ package org.apache.cassandra.db.compaction; import java.util.Collection; -import java.util.List; import java.util.concurrent.ExecutionException; +import com.google.common.collect.Iterables; import org.junit.BeforeClass; import org.junit.Test; @@ 
-102,14 +102,14 @@ public void testMajorCompactionPurge() .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // deletes for (int i = 0; i < 10; i++) { RowUpdateBuilder.deleteRow(cfs.metadata(), 1, key, String.valueOf(i)).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // resurrect one column RowUpdateBuilder builder = new RowUpdateBuilder(cfs.metadata(), 2, key); @@ -117,7 +117,7 @@ public void testMajorCompactionPurge() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // major compact and test that all columns but the resurrected one is completely gone FBUtilities.waitOnFutures(CompactionManager.instance.submitMaximal(cfs, Integer.MAX_VALUE, false)); @@ -146,14 +146,14 @@ public void testMajorCompactionPurgeTombstonesWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // deletes for (int i = 0; i < 10; i++) { RowUpdateBuilder.deleteRow(cfs.metadata(), Long.MAX_VALUE, key, String.valueOf(i)).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // major compact - tombstones should be purged FBUtilities.waitOnFutures(CompactionManager.instance.submitMaximal(cfs, Integer.MAX_VALUE, false)); @@ -164,7 +164,7 @@ public void testMajorCompactionPurgeTombstonesWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.invalidateCachedPartition(dk(key)); @@ -191,13 +191,13 @@ public void testMajorCompactionPurgeTopLevelTombstoneWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new Mutation.PartitionUpdateCollector(KEYSPACE1, dk(key)) 
.add(PartitionUpdate.fullPartitionDelete(cfs.metadata(), dk(key), Long.MAX_VALUE, FBUtilities.nowInSeconds())) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // major compact - tombstones should be purged FBUtilities.waitOnFutures(CompactionManager.instance.submitMaximal(cfs, Integer.MAX_VALUE, false)); @@ -208,7 +208,7 @@ public void testMajorCompactionPurgeTopLevelTombstoneWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.invalidateCachedPartition(dk(key)); @@ -235,11 +235,11 @@ public void testMajorCompactionPurgeRangeTombstoneWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), Long.MAX_VALUE, dk(key)) .addRangeTombstone(String.valueOf(0), String.valueOf(9)).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // major compact - tombstones should be purged FBUtilities.waitOnFutures(CompactionManager.instance.submitMaximal(cfs, Integer.MAX_VALUE, false)); @@ -250,7 +250,7 @@ public void testMajorCompactionPurgeRangeTombstoneWithMaxTimestamp() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.invalidateCachedPartition(dk(key)); @@ -278,7 +278,7 @@ public void testMinorCompactionPurge() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // deletes for (int i = 0; i < 10; i++) @@ -286,7 +286,7 @@ public void testMinorCompactionPurge() RowUpdateBuilder.deleteRow(cfs.metadata(), 1, key, String.valueOf(i)).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } DecoratedKey key1 = Util.dk("key1"); @@ -294,7 +294,7 @@ public void 
testMinorCompactionPurge() // flush, remember the current sstable and then resurrect one column // for first key. Then submit minor compaction on remembered sstables. - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstablesIncomplete = cfs.getLiveSSTables(); RowUpdateBuilder builder = new RowUpdateBuilder(cfs.metadata(), 2, "key1"); @@ -302,10 +302,11 @@ public void testMinorCompactionPurge() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); - List tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstablesIncomplete, Integer.MAX_VALUE); - assertEquals(1, tasks.size()); - tasks.get(0).execute(ActiveCompactionsTracker.NOOP); + cfs.forceBlockingFlushToSSTable(); + try (CompactionTasks tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstablesIncomplete, Integer.MAX_VALUE)) + { + Iterables.getOnlyElement(tasks).execute(ActiveCompactionsTracker.NOOP); + } // verify that minor compaction does GC when key is provably not // present in a non-compacted sstable @@ -342,21 +343,22 @@ public void testMinTimestampPurge() .add("val", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // delete c1 RowUpdateBuilder.deleteRow(cfs.metadata(), 10, key3, "c1").applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstablesIncomplete = cfs.getLiveSSTables(); // delete c2 so we have new delete in a diffrent SSTable RowUpdateBuilder.deleteRow(cfs.metadata(), 9, key3, "c2").applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // compact the sstables with the c1/c2 data and the c1 tombstone - List tasks = cfs.getCompactionStrategyManager().getUserDefinedTasks(sstablesIncomplete, Integer.MAX_VALUE); - assertEquals(1, tasks.size()); - tasks.get(0).execute(ActiveCompactionsTracker.NOOP); + try (CompactionTasks tasks = 
cfs.getCompactionStrategyManager().getUserDefinedTasks(sstablesIncomplete, Integer.MAX_VALUE)) + { + Iterables.getOnlyElement(tasks).execute(ActiveCompactionsTracker.NOOP); + } // We should have both the c1 and c2 tombstones still. Since the min timestamp in the c2 tombstone // sstable is older than the c1 tombstone, it is invalid to throw out the c1 tombstone. @@ -391,7 +393,7 @@ public void testCompactionPurgeOneFile() throws ExecutionException, InterruptedE { RowUpdateBuilder.deleteRow(cfs.metadata(), 1, key, String.valueOf(i)).applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(String.valueOf(cfs.getLiveSSTables()), 1, cfs.getLiveSSTables().size()); // inserts & deletes were in the same memtable -> only deletes in sstable // compact and test that the row is completely gone @@ -436,7 +438,7 @@ public void testCompactionPurgeCachedRow() throws ExecutionException, Interrupte assertFalse(Util.getOnlyPartitionUnfiltered(Util.cmd(cfs, key).build()).isEmpty()); // flush and major compact - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Util.compactAll(cfs, Integer.MAX_VALUE).get(); // Since we've force purging (by passing MAX_VALUE for gc_before), the row should have been invalidated and we should have no deletion info anymore @@ -472,7 +474,7 @@ public void testCompactionPurgeTombstonedRow() throws ExecutionException, Interr assertFalse(partition.partitionLevelDeletion().isLive()); // flush and major compact (with tombstone purging) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Util.compactAll(cfs, Integer.MAX_VALUE).get(); assertFalse(Util.getOnlyPartitionUnfiltered(Util.cmd(cfs, key).build()).isEmpty()); @@ -502,14 +504,14 @@ public void testRowTombstoneObservedBeforePurging() // write a row out to one sstable QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v1, v2) VALUES (%d, '%s', %d)", keyspace, table, 1, "foo", 1)); - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); UntypedResultSet result = QueryProcessor.executeInternal(String.format("SELECT * FROM %s.%s WHERE k = %d", keyspace, table, 1)); assertEquals(1, result.size()); // write a row tombstone out to a second sstable QueryProcessor.executeInternal(String.format("DELETE FROM %s.%s WHERE k = %d", keyspace, table, 1)); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // basic check that the row is considered deleted assertEquals(2, cfs.getLiveSSTables().size()); @@ -527,14 +529,14 @@ public void testRowTombstoneObservedBeforePurging() // write a row out to one sstable QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v1, v2) VALUES (%d, '%s', %d)", keyspace, table, 1, "foo", 1)); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(2, cfs.getLiveSSTables().size()); result = QueryProcessor.executeInternal(String.format("SELECT * FROM %s.%s WHERE k = %d", keyspace, table, 1)); assertEquals(1, result.size()); // write a row tombstone out to a different sstable QueryProcessor.executeInternal(String.format("DELETE FROM %s.%s WHERE k = %d", keyspace, table, 1)); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // compact the two sstables with a gcBefore that *does* allow the row tombstone to be purged FBUtilities.waitOnFutures(CompactionManager.instance.submitMaximal(cfs, (int) (System.currentTimeMillis() / 1000) + 10000, false)); diff --git a/test/unit/org/apache/cassandra/db/compaction/CompactionsTest.java b/test/unit/org/apache/cassandra/db/compaction/CompactionsTest.java index 941ef13eb250..a83b11e219ef 100644 --- a/test/unit/org/apache/cassandra/db/compaction/CompactionsTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/CompactionsTest.java @@ -156,7 +156,7 @@ public void testSingleSSTableCompaction() throws Exception long timestamp = populate(KEYSPACE1, CF_DENSE1, 0, 9, 3); //ttl=3s - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); 
assertEquals(1, store.getLiveSSTables().size()); long originalSize = store.getLiveSSTables().iterator().next().uncompressedLength(); @@ -196,11 +196,11 @@ public void testSuperColumnTombstones() .clustering(ByteBufferUtil.bytes("cols")) .add("val", "val1") .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // shadow the subcolumn with a supercolumn tombstone RowUpdateBuilder.deleteRow(table, FBUtilities.timestampMicros(), key.getKey(), ByteBufferUtil.bytes("cols")).applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(cfs, false); assertEquals(1, cfs.getLiveSSTables().size()); @@ -236,11 +236,11 @@ public void testUncheckedTombstoneSizeTieredCompaction() throws Exception //Populate sstable1 with with keys [0..9] populate(KEYSPACE1, CF_STANDARD1, 0, 9, 3); //ttl=3s - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); //Populate sstable2 with with keys [10..19] (keys do not overlap with SSTable1) long timestamp2 = populate(KEYSPACE1, CF_STANDARD1, 10, 19, 3); //ttl=3s - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(2, store.getLiveSSTables().size()); @@ -330,7 +330,7 @@ public void testUserDefinedCompaction() throws Exception .add("val", "val1") .build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstables = cfs.getLiveSSTables(); assertEquals(1, sstables.size()); @@ -365,7 +365,7 @@ public static void writeSSTableWithRangeTombstoneMaskingOneColumn(ColumnFamilySt notYetDeletedRowUpdateBuilder.clustering("02").add("val", "a"); //Range tombstone doesn't cover this (timestamp 3 > 2) notYetDeletedRowUpdateBuilder.build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } @Test @@ -450,7 +450,7 @@ private void testDontPurgeAccidentally(String k, String cfname) throws Interrupt rowUpdateBuilder.clustering("c").add("val", "a"); 
rowUpdateBuilder.build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstablesBefore = cfs.getLiveSSTables(); @@ -468,7 +468,7 @@ private void testDontPurgeAccidentally(String k, String cfname) throws Interrupt // Sleep one second so that the removal is indeed purgeable even with gcgrace == 0 Thread.sleep(1000); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Collection sstablesAfter = cfs.getLiveSSTables(); Collection toCompact = new ArrayList(); @@ -550,7 +550,7 @@ public void testNeedsCleanup() insertRowWithKey(i + 100); insertRowWithKey(i + 200); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); SSTableReader sstable = store.getLiveSSTables().iterator().next(); diff --git a/test/unit/org/apache/cassandra/db/compaction/DateTieredCompactionStrategyTest.java b/test/unit/org/apache/cassandra/db/compaction/DateTieredCompactionStrategyTest.java index f75842d05a85..6d04ccf71648 100644 --- a/test/unit/org/apache/cassandra/db/compaction/DateTieredCompactionStrategyTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/DateTieredCompactionStrategyTest.java @@ -231,9 +231,9 @@ public void testPrepBucket() .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List sstrs = new ArrayList<>(cfs.getLiveSSTables()); @@ -267,9 +267,9 @@ public void testFilterOldSSTables() .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Iterable filtered; List sstrs = new ArrayList<>(cfs.getLiveSSTables()); @@ -304,7 +304,7 @@ public void testDropExpiredSSTables() throws InterruptedException .clustering("column") .add("val", value).build().applyUnsafe(); - 
cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader expiredSSTable = cfs.getLiveSSTables().iterator().next(); Thread.sleep(10); @@ -313,7 +313,7 @@ public void testDropExpiredSSTables() throws InterruptedException .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(cfs.getLiveSSTables().size(), 2); Map options = new HashMap<>(); @@ -357,7 +357,7 @@ public void testSTCSBigWindow() .clustering("column") .add("val", bigValue).build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // and small ones: for (int r = 0; r < numSSTables / 2; r++) @@ -366,7 +366,7 @@ public void testSTCSBigWindow() new RowUpdateBuilder(cfs.metadata(), timestamp, key.getKey()) .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } Map options = new HashMap<>(); options.put(SizeTieredCompactionStrategyOptions.MIN_SSTABLE_SIZE_KEY, "1"); diff --git a/test/unit/org/apache/cassandra/db/compaction/LeveledCompactionStrategyTest.java b/test/unit/org/apache/cassandra/db/compaction/LeveledCompactionStrategyTest.java index 6c75e7bc63cf..7488a592a982 100644 --- a/test/unit/org/apache/cassandra/db/compaction/LeveledCompactionStrategyTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/LeveledCompactionStrategyTest.java @@ -127,7 +127,7 @@ public void testGrouperLevels() throws Exception{ for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } waitForLeveling(cfs); @@ -183,7 +183,7 @@ public void testValidationMultipleSSTablePerLevel() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } waitForLeveling(cfs); @@ -257,7 +257,7 @@ 
public void testCompactionProgress() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } waitForLeveling(cfs); @@ -294,9 +294,9 @@ public void testMutateLevel() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); LeveledCompactionStrategy strategy = (LeveledCompactionStrategy) cfs.getCompactionStrategyManager().getStrategies().get(1).get(0); cfs.forceMajorCompaction(); @@ -335,7 +335,7 @@ public void testNewRepairedSSTable() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } waitForLeveling(cfs); cfs.disableAutoCompaction(); @@ -414,7 +414,7 @@ public void testTokenRangeCompaction() throws Exception update.newRow("column" + c).add("val", value); update.applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // create 20 more sstables with 10 containing data for key1 and other 10 containing data for key2 @@ -424,7 +424,7 @@ public void testTokenRangeCompaction() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } @@ -475,7 +475,7 @@ public void testCompactionCandidateOrdering() throws Exception for (int c = 0; c < columns; c++) update.newRow("column" + c).add("val", value); update.applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } LeveledCompactionStrategy strategy = (LeveledCompactionStrategy) (cfs.getCompactionStrategyManager()).getStrategies().get(1).get(0); // get readers for level 0 sstables diff --git 
a/test/unit/org/apache/cassandra/db/compaction/NeverPurgeTest.java b/test/unit/org/apache/cassandra/db/compaction/NeverPurgeTest.java index 0d5bc81b024d..905f415e35b5 100644 --- a/test/unit/org/apache/cassandra/db/compaction/NeverPurgeTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/NeverPurgeTest.java @@ -72,13 +72,13 @@ public void minorNeverPurgeTombstonesTest() throws Throwable { execute("INSERT INTO %s (a, b, c) VALUES (" + j + ", 2, '3')"); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } execute("UPDATE %s SET c = null WHERE a=1 AND b=2"); execute("DELETE FROM %s WHERE a=2 AND b=2"); execute("DELETE FROM %s WHERE a=3"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.enableAutoCompaction(); while (cfs.getLiveSSTables().size() > 1 || !cfs.getTracker().getCompacting().isEmpty()) Thread.sleep(100); @@ -92,7 +92,7 @@ private void testHelper(String deletionStatement) throws Throwable execute("INSERT INTO %s (a, b, c) VALUES (1, 2, '3')"); execute(deletionStatement); Thread.sleep(1000); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.forceMajorCompaction(); verifyContainsTombstones(cfs.getLiveSSTables(), 1); } diff --git a/test/unit/org/apache/cassandra/db/compaction/OneCompactionTest.java b/test/unit/org/apache/cassandra/db/compaction/OneCompactionTest.java index 0c469dc534b3..41fdadec8cdc 100644 --- a/test/unit/org/apache/cassandra/db/compaction/OneCompactionTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/OneCompactionTest.java @@ -71,7 +71,7 @@ private void testCompaction(String columnFamilyName, int insertsPerTable) .applyUnsafe(); inserted.add(key); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(inserted.size(), Util.getAll(Util.cmd(store).build()).size()); } CompactionManager.instance.performMaximal(store, false); diff --git a/test/unit/org/apache/cassandra/db/compaction/PendingRepairManagerTest.java 
b/test/unit/org/apache/cassandra/db/compaction/PendingRepairManagerTest.java index 4e645fd9f805..9f4cf8de690c 100644 --- a/test/unit/org/apache/cassandra/db/compaction/PendingRepairManagerTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/PendingRepairManagerTest.java @@ -20,7 +20,6 @@ import java.util.Collection; import java.util.Collections; -import java.util.List; import java.util.UUID; import com.google.common.collect.Lists; @@ -228,15 +227,11 @@ public void userDefinedTaskTest() SSTableReader sstable = makeSSTable(true); mutateRepaired(sstable, repairId, false); prm.addSSTable(sstable); - List tasks = csm.getUserDefinedTasks(Collections.singleton(sstable), 100); - try + + try (CompactionTasks tasks = csm.getUserDefinedTasks(Collections.singleton(sstable), 100)) { Assert.assertEquals(1, tasks.size()); } - finally - { - tasks.stream().forEach(t -> t.transaction.abort()); - } } @Test @@ -252,15 +247,10 @@ public void mixedPendingSessionsTest() mutateRepaired(sstable2, repairId2, false); prm.addSSTable(sstable); prm.addSSTable(sstable2); - List tasks = csm.getUserDefinedTasks(Lists.newArrayList(sstable, sstable2), 100); - try + try (CompactionTasks tasks = csm.getUserDefinedTasks(Lists.newArrayList(sstable, sstable2), 100)) { Assert.assertEquals(2, tasks.size()); } - finally - { - tasks.stream().forEach(t -> t.transaction.abort()); - } } /** diff --git a/test/unit/org/apache/cassandra/db/compaction/SingleSSTableLCSTaskTest.java b/test/unit/org/apache/cassandra/db/compaction/SingleSSTableLCSTaskTest.java index 61cf302d6c51..38499cde415c 100644 --- a/test/unit/org/apache/cassandra/db/compaction/SingleSSTableLCSTaskTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/SingleSSTableLCSTaskTest.java @@ -42,7 +42,7 @@ public void basicTest() throws Throwable createTable("create table %s (id int primary key, t text) with compaction = {'class':'LeveledCompactionStrategy','single_sstable_uplevel':true}"); ColumnFamilyStore cfs = 
getCurrentColumnFamilyStore(); execute("insert into %s (id, t) values (1, 'meep')"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); try (LifecycleTransaction txn = cfs.getTracker().tryModify(sstable, OperationType.COMPACTION)) @@ -95,7 +95,7 @@ private void compactionTestHelper(boolean singleSSTUplevel) throws Throwable execute("insert into %s (id, id2, t) values (?, ?, ?)", i, j, value); } if (i % 100 == 0) - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // now we have a bunch of data in L0, first compaction will be a normal one, containing all sstables: LeveledCompactionStrategy lcs = (LeveledCompactionStrategy) cfs.getCompactionStrategyManager().getUnrepairedUnsafe().first(); @@ -123,7 +123,7 @@ public void corruptMetadataTest() throws Throwable createTable("create table %s (id int primary key, t text) with compaction = {'class':'LeveledCompactionStrategy','single_sstable_uplevel':true}"); ColumnFamilyStore cfs = getCurrentColumnFamilyStore(); execute("insert into %s (id, t) values (1, 'meep')"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); String filenameToCorrupt = sstable.descriptor.filenameFor(Component.STATS); diff --git a/test/unit/org/apache/cassandra/db/compaction/SizeTieredCompactionStrategyTest.java b/test/unit/org/apache/cassandra/db/compaction/SizeTieredCompactionStrategyTest.java index 00c4a86e0dd0..a7313a5d4945 100644 --- a/test/unit/org/apache/cassandra/db/compaction/SizeTieredCompactionStrategyTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/SizeTieredCompactionStrategyTest.java @@ -165,9 +165,9 @@ public void testPrepBucket() throws Exception new RowUpdateBuilder(cfs.metadata(), 0, key) .clustering("column").add("val", value) .build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } - cfs.forceBlockingFlush(); + 
cfs.forceBlockingFlushToSSTable(); List sstrs = new ArrayList<>(cfs.getLiveSSTables()); Pair, Double> bucket; diff --git a/test/unit/org/apache/cassandra/db/compaction/TTLExpiryTest.java b/test/unit/org/apache/cassandra/db/compaction/TTLExpiryTest.java index a2352fcf02aa..31dbb09441ae 100644 --- a/test/unit/org/apache/cassandra/db/compaction/TTLExpiryTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/TTLExpiryTest.java @@ -94,7 +94,7 @@ public void testAggressiveFullyExpired() .add("col2", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 2L, 1, key) .add("col1", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() @@ -105,7 +105,7 @@ public void testAggressiveFullyExpired() .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 4L, 1, key) .add("col1", ByteBufferUtil.EMPTY_BYTE_BUFFER) @@ -117,7 +117,7 @@ public void testAggressiveFullyExpired() .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), 6L, 3, key) @@ -130,7 +130,7 @@ public void testAggressiveFullyExpired() .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set sstables = Sets.newHashSet(cfs.getLiveSSTables()); int now = (int)(System.currentTimeMillis() / 1000); @@ -173,7 +173,7 @@ public void testSimpleExpire(boolean force10944Bug) throws InterruptedException .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), timestamp, 1, key) .add("col2", ByteBufferUtil.EMPTY_BYTE_BUFFER) @@ -183,7 +183,7 @@ public void testSimpleExpire(boolean force10944Bug) throws InterruptedException .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // To reproduce #10944, we need to avoid the optimization that get rid of full sstable 
because everything // is known to be gcAble, so keep some data non-expiring in that case. new RowUpdateBuilder(cfs.metadata(), timestamp, force10944Bug ? 0 : 1, key) @@ -192,14 +192,14 @@ public void testSimpleExpire(boolean force10944Bug) throws InterruptedException .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), timestamp, 1, key) .add("col311", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Thread.sleep(2000); // wait for ttl to expire assertEquals(4, cfs.getLiveSSTables().size()); cfs.enableAutoCompaction(true); @@ -221,24 +221,24 @@ public void testNoExpire() throws InterruptedException, IOException .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), timestamp, 1, key) .add("col2", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); new RowUpdateBuilder(cfs.metadata(), timestamp, 1, key) .add("col3", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); String noTTLKey = "nottl"; new RowUpdateBuilder(cfs.metadata(), timestamp, noTTLKey) .add("col311", ByteBufferUtil.EMPTY_BYTE_BUFFER) .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Thread.sleep(2000); // wait for ttl to expire assertEquals(4, cfs.getLiveSSTables().size()); cfs.enableAutoCompaction(true); @@ -270,7 +270,7 @@ public void testCheckForExpiredSSTableBlockers() throws InterruptedException .build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader blockingSSTable = cfs.getSSTables(SSTableSet.LIVE).iterator().next(); for (int i = 0; i < 10; i++) { @@ -279,7 +279,7 @@ public void testCheckForExpiredSSTableBlockers() throws InterruptedException .delete("col1") 
.build() .applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } Multimap blockers = SSTableExpiredBlockers.checkForExpiredSSTableBlockers(cfs.getSSTables(SSTableSet.LIVE), (int) (System.currentTimeMillis() / 1000) + 100); assertEquals(1, blockers.keySet().size()); diff --git a/test/unit/org/apache/cassandra/db/compaction/TimeWindowCompactionStrategyTest.java b/test/unit/org/apache/cassandra/db/compaction/TimeWindowCompactionStrategyTest.java index 89dd2f59b162..69a43fc94791 100644 --- a/test/unit/org/apache/cassandra/db/compaction/TimeWindowCompactionStrategyTest.java +++ b/test/unit/org/apache/cassandra/db/compaction/TimeWindowCompactionStrategyTest.java @@ -168,7 +168,7 @@ public void testPrepBucket() .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // Decrement the timestamp to simulate a timestamp in the past hour for (int r = 3; r < 5; r++) @@ -178,10 +178,10 @@ public void testPrepBucket() new RowUpdateBuilder(cfs.metadata(), r, key.getKey()) .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); HashMultimap buckets = HashMultimap.create(); List sstrs = new ArrayList<>(cfs.getLiveSSTables()); @@ -220,7 +220,7 @@ public void testPrepBucket() .clustering("column") .add("val", value).build().applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // Reset the buckets, overfill it now @@ -252,7 +252,7 @@ public void testDropExpiredSSTables() throws InterruptedException .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader expiredSSTable = cfs.getLiveSSTables().iterator().next(); Thread.sleep(10); @@ -261,7 +261,7 @@ public void testDropExpiredSSTables() throws InterruptedException 
.clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(cfs.getLiveSSTables().size(), 2); Map options = new HashMap<>(); @@ -300,7 +300,7 @@ public void testDropOverlappingExpiredSSTables() throws InterruptedException .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader expiredSSTable = cfs.getLiveSSTables().iterator().next(); Thread.sleep(10); @@ -312,7 +312,7 @@ public void testDropOverlappingExpiredSSTables() throws InterruptedException .clustering("column") .add("val", value).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(cfs.getLiveSSTables().size(), 2); Map options = new HashMap<>(); diff --git a/test/unit/org/apache/cassandra/db/monitoring/MonitoringTaskTest.java b/test/unit/org/apache/cassandra/db/monitoring/MonitoringTaskTest.java index acc988f7228f..454d0b4915c7 100644 --- a/test/unit/org/apache/cassandra/db/monitoring/MonitoringTaskTest.java +++ b/test/unit/org/apache/cassandra/db/monitoring/MonitoringTaskTest.java @@ -32,6 +32,11 @@ import org.junit.BeforeClass; import org.junit.Test; +import org.apache.cassandra.utils.ApproximateTime; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -39,8 +44,8 @@ public class MonitoringTaskTest { - private static final long timeout = 100; - private static final long slowTimeout = 10; + private static final long timeout = MILLISECONDS.toNanos(100); + private static final long slowTimeout = MILLISECONDS.toNanos(10); private static final long MAX_SPIN_TIME_NANOS = TimeUnit.SECONDS.toNanos(5); @@ -90,8 +95,8 @@ private static 
void waitForOperationsToComplete(Monitorable... operations) throw private static void waitForOperationsToComplete(List operations) throws InterruptedException { - long timeout = operations.stream().map(Monitorable::timeout).reduce(0L, Long::max); - Thread.sleep(timeout * 2 + ApproximateTime.precision()); + long timeout = operations.stream().map(Monitorable::timeoutNanos).reduce(0L, Long::max); + Thread.sleep(NANOSECONDS.toMillis(timeout * 2 + approxTime.error())); long start = System.nanoTime(); while(System.nanoTime() - start <= MAX_SPIN_TIME_NANOS) @@ -109,8 +114,8 @@ private static void waitForOperationsToBeReportedAsSlow(Monitorable... operation private static void waitForOperationsToBeReportedAsSlow(List operations) throws InterruptedException { - long timeout = operations.stream().map(Monitorable::slowTimeout).reduce(0L, Long::max); - Thread.sleep(timeout * 2 + ApproximateTime.precision()); + long timeout = operations.stream().map(Monitorable::slowTimeoutNanos).reduce(0L, Long::max); + Thread.sleep(NANOSECONDS.toMillis(timeout * 2 + approxTime.error())); long start = System.nanoTime(); while(System.nanoTime() - start <= MAX_SPIN_TIME_NANOS) @@ -124,7 +129,7 @@ private static void waitForOperationsToBeReportedAsSlow(List operat @Test public void testAbort() throws InterruptedException { - Monitorable operation = new TestMonitor("Test abort", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test abort", System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToComplete(operation); assertTrue(operation.isAborted()); @@ -135,7 +140,7 @@ public void testAbort() throws InterruptedException @Test public void testAbortIdemPotent() throws InterruptedException { - Monitorable operation = new TestMonitor("Test abort", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test abort", System.nanoTime(), false, timeout, slowTimeout); 
waitForOperationsToComplete(operation); assertTrue(operation.abort()); @@ -148,7 +153,7 @@ public void testAbortIdemPotent() throws InterruptedException @Test public void testAbortCrossNode() throws InterruptedException { - Monitorable operation = new TestMonitor("Test for cross node", System.currentTimeMillis(), true, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test for cross node", System.nanoTime(), true, timeout, slowTimeout); waitForOperationsToComplete(operation); assertTrue(operation.isAborted()); @@ -159,7 +164,7 @@ public void testAbortCrossNode() throws InterruptedException @Test public void testComplete() throws InterruptedException { - Monitorable operation = new TestMonitor("Test complete", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test complete", System.nanoTime(), false, timeout, slowTimeout); operation.complete(); waitForOperationsToComplete(operation); @@ -171,7 +176,7 @@ public void testComplete() throws InterruptedException @Test public void testCompleteIdemPotent() throws InterruptedException { - Monitorable operation = new TestMonitor("Test complete", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test complete", System.nanoTime(), false, timeout, slowTimeout); operation.complete(); waitForOperationsToComplete(operation); @@ -185,7 +190,7 @@ public void testCompleteIdemPotent() throws InterruptedException @Test public void testReportSlow() throws InterruptedException { - Monitorable operation = new TestMonitor("Test report slow", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test report slow", System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToBeReportedAsSlow(operation); assertTrue(operation.isSlow()); @@ -199,7 +204,7 @@ public void testReportSlow() throws InterruptedException public void testNoReportSlowIfZeroSlowTimeout() throws 
InterruptedException { // when the slow timeout is set to zero then operation won't be reported as slow - Monitorable operation = new TestMonitor("Test report slow disabled", System.currentTimeMillis(), false, timeout, 0); + Monitorable operation = new TestMonitor("Test report slow disabled", System.nanoTime(), false, timeout, 0); waitForOperationsToBeReportedAsSlow(operation); assertTrue(operation.isSlow()); @@ -212,7 +217,7 @@ public void testNoReportSlowIfZeroSlowTimeout() throws InterruptedException @Test public void testReport() throws InterruptedException { - Monitorable operation = new TestMonitor("Test report", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation = new TestMonitor("Test report", System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToComplete(operation); assertTrue(operation.isSlow()); @@ -220,10 +225,10 @@ public void testReport() throws InterruptedException assertFalse(operation.isCompleted()); // aborted operations are not logged as slow - assertFalse(MonitoringTask.instance.logSlowOperations(ApproximateTime.currentTimeMillis())); + assertFalse(MonitoringTask.instance.logSlowOperations(approxTime.now())); assertEquals(0, MonitoringTask.instance.getSlowOperations().size()); - assertTrue(MonitoringTask.instance.logFailedOperations(ApproximateTime.currentTimeMillis())); + assertTrue(MonitoringTask.instance.logFailedOperations(approxTime.now())); assertEquals(0, MonitoringTask.instance.getFailedOperations().size()); } @@ -233,20 +238,20 @@ public void testRealScheduling() throws InterruptedException MonitoringTask.instance = MonitoringTask.make(10, -1); try { - Monitorable operation1 = new TestMonitor("Test report 1", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation1 = new TestMonitor("Test report 1", System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToComplete(operation1); assertTrue(operation1.isAborted()); assertFalse(operation1.isCompleted()); - 
Monitorable operation2 = new TestMonitor("Test report 2", System.currentTimeMillis(), false, timeout, slowTimeout); + Monitorable operation2 = new TestMonitor("Test report 2", System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToBeReportedAsSlow(operation2); operation2.complete(); assertFalse(operation2.isAborted()); assertTrue(operation2.isCompleted()); - Thread.sleep(ApproximateTime.precision() + 500); + Thread.sleep(2 * NANOSECONDS.toMillis(approxTime.error()) + 500); assertEquals(0, MonitoringTask.instance.getFailedOperations().size()); assertEquals(0, MonitoringTask.instance.getSlowOperations().size()); } @@ -266,7 +271,7 @@ public void testMultipleThreads() throws InterruptedException for (int i = 0; i < opCount; i++) { executorService.submit(() -> - operations.add(new TestMonitor(UUID.randomUUID().toString(), System.currentTimeMillis(), false, timeout, slowTimeout)) + operations.add(new TestMonitor(UUID.randomUUID().toString(), System.nanoTime(), false, timeout, slowTimeout)) ); } @@ -311,14 +316,14 @@ private static void doTestMaxTimedoutOperations(int maxTimedoutOperations, for (int j = 0; j < numTimes; j++) { Monitorable operation1 = new TestMonitor(operationName, - System.currentTimeMillis(), + System.nanoTime(), false, timeout, slowTimeout); waitForOperationsToComplete(operation1); Monitorable operation2 = new TestMonitor(operationName, - System.currentTimeMillis(), + System.nanoTime(), false, timeout, slowTimeout); @@ -366,7 +371,7 @@ public void testMultipleThreadsSameNameFailed() throws InterruptedException try { Monitorable operation = new TestMonitor("Test testMultipleThreadsSameName failed", - System.currentTimeMillis(), + System.nanoTime(), false, timeout, slowTimeout); @@ -400,7 +405,7 @@ public void testMultipleThreadsSameNameSlow() throws InterruptedException try { Monitorable operation = new TestMonitor("Test testMultipleThreadsSameName slow", - System.currentTimeMillis(), + System.nanoTime(), false, timeout, slowTimeout); @@ 
-436,7 +441,7 @@ public void testMultipleThreadsNoFailedOps() throws InterruptedException try { Monitorable operation = new TestMonitor("Test thread " + Thread.currentThread().getName(), - System.currentTimeMillis(), + System.nanoTime(), false, timeout, slowTimeout); diff --git a/test/unit/org/apache/cassandra/db/repair/AbstractPendingAntiCompactionTest.java b/test/unit/org/apache/cassandra/db/repair/AbstractPendingAntiCompactionTest.java index 62b7db148465..5a4d7c171ce8 100644 --- a/test/unit/org/apache/cassandra/db/repair/AbstractPendingAntiCompactionTest.java +++ b/test/unit/org/apache/cassandra/db/repair/AbstractPendingAntiCompactionTest.java @@ -109,7 +109,7 @@ void makeSSTables(int num, ColumnFamilyStore cfs, int rowsPerSSTable) int val = i * rowsPerSSTable; // multiplied to prevent ranges from overlapping for (int j = 0; j < rowsPerSSTable; j++) QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v) VALUES (?, ?)", ks, cfs.getTableName()), val + j, val + j); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } Assert.assertEquals(num, cfs.getLiveSSTables().size()); } diff --git a/test/unit/org/apache/cassandra/db/repair/CompactionManagerGetSSTablesForValidationTest.java b/test/unit/org/apache/cassandra/db/repair/CompactionManagerGetSSTablesForValidationTest.java index 3b29cc5b50d7..08044e9c7346 100644 --- a/test/unit/org/apache/cassandra/db/repair/CompactionManagerGetSSTablesForValidationTest.java +++ b/test/unit/org/apache/cassandra/db/repair/CompactionManagerGetSSTablesForValidationTest.java @@ -93,7 +93,7 @@ private void makeSSTables() for (int i=0; i<3; i++) { QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v) VALUES(?, ?)", ks, tbl), i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } Assert.assertEquals(3, cfs.getLiveSSTables().size()); diff --git a/test/unit/org/apache/cassandra/db/repair/PendingAntiCompactionTest.java 
b/test/unit/org/apache/cassandra/db/repair/PendingAntiCompactionTest.java index b140813f48e3..b91c315a0db1 100644 --- a/test/unit/org/apache/cassandra/db/repair/PendingAntiCompactionTest.java +++ b/test/unit/org/apache/cassandra/db/repair/PendingAntiCompactionTest.java @@ -122,12 +122,12 @@ public void successCase() throws Exception { QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v) VALUES (?, ?)", ks, tbl), i, i); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); for (int i = 8; i < 12; i++) { QueryProcessor.executeInternal(String.format("INSERT INTO %s.%s (k, v) VALUES (?, ?)", ks, tbl), i, i); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(2, cfs.getLiveSSTables().size()); Token left = ByteOrderedPartitioner.instance.getToken(ByteBufferUtil.bytes((int) 6)); diff --git a/test/unit/org/apache/cassandra/db/rows/ThrottledUnfilteredIteratorTest.java b/test/unit/org/apache/cassandra/db/rows/ThrottledUnfilteredIteratorTest.java index cc886f1c24f8..bba7334dee4a 100644 --- a/test/unit/org/apache/cassandra/db/rows/ThrottledUnfilteredIteratorTest.java +++ b/test/unit/org/apache/cassandra/db/rows/ThrottledUnfilteredIteratorTest.java @@ -111,7 +111,7 @@ public void emptyPartitionDeletionTest() throws Throwable // flush and generate 1 sstable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.disableAutoCompaction(); cfs.forceMajorCompaction(); @@ -146,7 +146,7 @@ public void emptyStaticTest() throws Throwable // flush and generate 1 sstable ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.disableAutoCompaction(); cfs.forceMajorCompaction(); @@ -204,7 +204,7 @@ else if (ck1 == ck2 - 1) // cell tombstone // flush and generate 1 sstable ColumnFamilyStore cfs = 
Keyspace.open(keyspace()).getColumnFamilyStore(currentTable()); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.disableAutoCompaction(); cfs.forceMajorCompaction(); @@ -623,7 +623,7 @@ public void testThrottledIteratorWithRangeDeletions() throws Exception new RowUpdateBuilder(cfs.metadata(), 1, key).addRangeTombstone(10, 22).build().applyUnsafe(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); builder = UpdateBuilder.create(cfs.metadata(), key).withTimestamp(2); for (int i = 1; i < 40; i += 2) diff --git a/test/unit/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriterTest.java b/test/unit/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriterTest.java index 947f9687f4a0..baee189af016 100644 --- a/test/unit/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriterTest.java +++ b/test/unit/org/apache/cassandra/db/streaming/CassandraEntireSSTableStreamWriterTest.java @@ -42,10 +42,10 @@ import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.io.sstable.SSTableMultiWriter; import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.ByteBufDataInputPlus; -import org.apache.cassandra.net.async.ByteBufDataOutputStreamPlus; -import org.apache.cassandra.net.async.NonClosingDefaultFileRegion; +import org.apache.cassandra.net.SharedDefaultFileRegion; +import org.apache.cassandra.net.AsyncStreamingOutputPlus; import org.apache.cassandra.schema.CachingParams; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.streaming.DefaultConnectionFactory; @@ -100,7 +100,7 @@ public static void defineSchemaAndPrepareSSTable() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); sstable = 
store.getLiveSSTables().iterator().next(); @@ -114,7 +114,7 @@ public void testBlockWriterOverWire() throws IOException CassandraEntireSSTableStreamWriter writer = new CassandraEntireSSTableStreamWriter(sstable, session, CassandraOutgoingFile.getComponentManifest(sstable)); EmbeddedChannel channel = new EmbeddedChannel(); - ByteBufDataOutputStreamPlus out = ByteBufDataOutputStreamPlus.create(session, channel, 1024 * 1024); + AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(channel); writer.write(out); Queue msgs = channel.outboundMessages(); @@ -133,7 +133,7 @@ public void testBlockReadingAndWritingOverWire() throws Exception // This is needed as Netty releases the ByteBuffers as soon as the channel is flushed ByteBuf serializedFile = Unpooled.buffer(8192); EmbeddedChannel channel = createMockNettyChannel(serializedFile); - ByteBufDataOutputStreamPlus out = ByteBufDataOutputStreamPlus.create(session, channel, 1024 * 1024); + AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(channel); writer.write(out); @@ -155,7 +155,7 @@ public void testBlockReadingAndWritingOverWire() throws Exception CassandraEntireSSTableStreamReader reader = new CassandraEntireSSTableStreamReader(new StreamMessageHeader(sstable.metadata().id, peer, session.planId(), 0, 0, 0, null), header, session); - SSTableMultiWriter sstableWriter = reader.read(new ByteBufDataInputPlus(serializedFile)); + SSTableMultiWriter sstableWriter = reader.read(new DataInputBuffer(serializedFile.nioBuffer(), false)); Collection newSstables = sstableWriter.finished(); assertEquals(1, newSstables.size()); @@ -188,7 +188,7 @@ public void close() throws IOException @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - ((NonClosingDefaultFileRegion) msg).transferTo(wbc, 0); + ((SharedDefaultFileRegion) msg).transferTo(wbc, 0); super.write(ctx, msg, promise); } }); diff --git 
a/test/unit/org/apache/cassandra/db/streaming/CassandraOutgoingFileTest.java b/test/unit/org/apache/cassandra/db/streaming/CassandraOutgoingFileTest.java index 5e443463d522..bf25f1a1931f 100644 --- a/test/unit/org/apache/cassandra/db/streaming/CassandraOutgoingFileTest.java +++ b/test/unit/org/apache/cassandra/db/streaming/CassandraOutgoingFileTest.java @@ -78,7 +78,7 @@ public static void defineSchemaAndPrepareSSTable() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); sstable = store.getLiveSSTables().iterator().next(); diff --git a/test/unit/org/apache/cassandra/db/streaming/CassandraStreamManagerTest.java b/test/unit/org/apache/cassandra/db/streaming/CassandraStreamManagerTest.java index eb15e9aa7339..93f769b66479 100644 --- a/test/unit/org/apache/cassandra/db/streaming/CassandraStreamManagerTest.java +++ b/test/unit/org/apache/cassandra/db/streaming/CassandraStreamManagerTest.java @@ -113,7 +113,7 @@ private SSTableReader createSSTable(Runnable queryable) { Set before = cfs.getLiveSSTables(); queryable.run(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set after = cfs.getLiveSSTables(); Set diff = Sets.difference(after, before); diff --git a/test/unit/org/apache/cassandra/db/streaming/ComponentManifestTest.java b/test/unit/org/apache/cassandra/db/streaming/ComponentManifestTest.java index f478a008d908..4909263440af 100644 --- a/test/unit/org/apache/cassandra/db/streaming/ComponentManifestTest.java +++ b/test/unit/org/apache/cassandra/db/streaming/ComponentManifestTest.java @@ -20,18 +20,15 @@ import java.io.EOFException; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.LinkedHashMap; import org.junit.Test; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import org.apache.cassandra.io.sstable.Component; -import org.apache.cassandra.io.util.DataInputPlus; -import 
org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.io.util.DataOutputBufferFixed; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.ByteBufDataInputPlus; -import org.apache.cassandra.net.async.ByteBufDataOutputPlus; import org.apache.cassandra.serializers.SerializationUtils; import static org.junit.Assert.assertNotEquals; @@ -48,17 +45,17 @@ public void testSerialization() @Test(expected = EOFException.class) public void testSerialization_FailsOnBadBytes() throws IOException { - ByteBuf buf = Unpooled.buffer(512); + ByteBuffer buf = ByteBuffer.allocate(512); ComponentManifest expected = new ComponentManifest(new LinkedHashMap() {{ put(Component.DATA, 100L); }}); - DataOutputPlus output = new ByteBufDataOutputPlus(buf); - ComponentManifest.serializer.serialize(expected, output, MessagingService.VERSION_40); + DataOutputBufferFixed out = new DataOutputBufferFixed(buf); - buf.setInt(0, -100); + ComponentManifest.serializer.serialize(expected, out, MessagingService.VERSION_40); - DataInputPlus input = new ByteBufDataInputPlus(buf); - ComponentManifest actual = ComponentManifest.serializer.deserialize(input, MessagingService.VERSION_40); + buf.putInt(0, -100); + DataInputBuffer in = new DataInputBuffer(out.buffer(), false); + ComponentManifest actual = ComponentManifest.serializer.deserialize(in, MessagingService.VERSION_40); assertNotEquals(expected, actual); } } diff --git a/test/unit/org/apache/cassandra/db/view/ViewBuilderTaskTest.java b/test/unit/org/apache/cassandra/db/view/ViewBuilderTaskTest.java index 2341c730a423..bc306791a1c4 100644 --- a/test/unit/org/apache/cassandra/db/view/ViewBuilderTaskTest.java +++ b/test/unit/org/apache/cassandra/db/view/ViewBuilderTaskTest.java @@ -85,7 +85,7 @@ private void test(int indexOfStartToken, { // Truncate the materialized view (not the base table) cfs.viewManager.forceBlockingFlush(); - 
cfs.viewManager.truncateBlocking(cfs.forceBlockingFlush(), System.currentTimeMillis()); + cfs.viewManager.truncateBlocking(cfs.forceBlockingFlushToSSTable(), System.currentTimeMillis()); assertRowCount(execute("SELECT * FROM " + viewName), 0); // Get the tokens from the referenced inserted rows diff --git a/test/unit/org/apache/cassandra/db/virtual/SettingsTableTest.java b/test/unit/org/apache/cassandra/db/virtual/SettingsTableTest.java index 3e566617b9b2..d34878d68532 100644 --- a/test/unit/org/apache/cassandra/db/virtual/SettingsTableTest.java +++ b/test/unit/org/apache/cassandra/db/virtual/SettingsTableTest.java @@ -136,40 +136,40 @@ public void testEncryptionOverride() throws Throwable String all = "SELECT * FROM vts.settings WHERE " + "name > 'server_encryption' AND name < 'server_encryptionz' ALLOW FILTERING"; - config.server_encryption_options.enabled = true; + config.server_encryption_options = config.server_encryption_options.withEnabled(true); Assert.assertEquals(9, executeNet(all).all().size()); check(pre + "enabled", "true"); check(pre + "algorithm", null); - config.server_encryption_options.algorithm = "SUPERSSL"; + config.server_encryption_options = config.server_encryption_options.withAlgorithm("SUPERSSL"); check(pre + "algorithm", "SUPERSSL"); check(pre + "cipher_suites", "[]"); - config.server_encryption_options.cipher_suites = new String[]{"c1", "c2"}; + config.server_encryption_options = config.server_encryption_options.withCipherSuites("c1", "c2"); check(pre + "cipher_suites", "[c1, c2]"); check(pre + "protocol", config.server_encryption_options.protocol); - config.server_encryption_options.protocol = "TLSv5"; + config.server_encryption_options = config.server_encryption_options.withProtocol("TLSv5"); check(pre + "protocol", "TLSv5"); check(pre + "optional", "false"); - config.server_encryption_options.optional = true; + config.server_encryption_options = config.server_encryption_options.withOptional(true); check(pre + "optional", "true"); 
check(pre + "client_auth", "false"); - config.server_encryption_options.require_client_auth = true; + config.server_encryption_options = config.server_encryption_options.withRequireClientAuth(true); check(pre + "client_auth", "true"); check(pre + "endpoint_verification", "false"); - config.server_encryption_options.require_endpoint_verification = true; + config.server_encryption_options = config.server_encryption_options.withRequireEndpointVerification(true); check(pre + "endpoint_verification", "true"); check(pre + "internode_encryption", "none"); - config.server_encryption_options.internode_encryption = InternodeEncryption.all; + config.server_encryption_options = config.server_encryption_options.withInternodeEncryption(InternodeEncryption.all); check(pre + "internode_encryption", "all"); check(pre + "legacy_ssl_storage_port", "false"); - config.server_encryption_options.enable_legacy_ssl_storage_port = true; + config.server_encryption_options = config.server_encryption_options.withLegacySslStoragePort(true); check(pre + "legacy_ssl_storage_port", "true"); } diff --git a/test/unit/org/apache/cassandra/dht/BootStrapperTest.java b/test/unit/org/apache/cassandra/dht/BootStrapperTest.java index 2f412ad63793..c0b6d5cb70ea 100644 --- a/test/unit/org/apache/cassandra/dht/BootStrapperTest.java +++ b/test/unit/org/apache/cassandra/dht/BootStrapperTest.java @@ -174,6 +174,17 @@ public void testAllocateTokens() throws UnknownHostException allocateTokensForNode(vn, ks, tm, addr); } + @Test + public void testAllocateTokensLocalRf() throws UnknownHostException + { + int vn = 16; + int allocateTokensForLocalRf = 3; + TokenMetadata tm = new TokenMetadata(); + generateFakeEndpoints(tm, 10, vn); + InetAddressAndPort addr = FBUtilities.getBroadcastAddressAndPort(); + allocateTokensForNode(vn, allocateTokensForLocalRf, tm, addr); + } + public void testAllocateTokensNetworkStrategy(int rackCount, int replicas) throws UnknownHostException { IEndpointSnitch oldSnitch = 
DatabaseDescriptor.getEndpointSnitch(); @@ -243,6 +254,14 @@ private void allocateTokensForNode(int vn, String ks, TokenMetadata tm, InetAddr verifyImprovement(os, ns); } + private void allocateTokensForNode(int vn, int rf, TokenMetadata tm, InetAddressAndPort addr) + { + Collection tokens = BootStrapper.allocateTokens(tm, addr, rf, vn, 0); + assertEquals(vn, tokens.size()); + tm.updateNormalTokens(tokens, addr); + // SummaryStatistics is not implemented for `allocate_tokens_for_local_replication_factor` so can't be verified + } + private void verifyImprovement(SummaryStatistics os, SummaryStatistics ns) { if (ns.getStandardDeviation() > os.getStandardDeviation()) diff --git a/test/unit/org/apache/cassandra/gms/GossiperTest.java b/test/unit/org/apache/cassandra/gms/GossiperTest.java index 9c25b86739e3..97c577c2cd75 100644 --- a/test/unit/org/apache/cassandra/gms/GossiperTest.java +++ b/test/unit/org/apache/cassandra/gms/GossiperTest.java @@ -21,6 +21,7 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.UUID; @@ -117,33 +118,119 @@ public void testHaveVersion3Nodes() throws Exception public void testLargeGenerationJump() throws UnknownHostException, InterruptedException { Util.createInitialRing(ss, partitioner, endpointTokens, keyTokens, hosts, hostIds, 2); - InetAddressAndPort remoteHostAddress = hosts.get(1); + try + { + InetAddressAndPort remoteHostAddress = hosts.get(1); + + EndpointState initialRemoteState = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress); + HeartBeatState initialRemoteHeartBeat = initialRemoteState.getHeartBeatState(); + + //Util.createInitialRing should have initialized remoteHost's HeartBeatState's generation to 1 + assertEquals(initialRemoteHeartBeat.getGeneration(), 1); + + HeartBeatState proposedRemoteHeartBeat = new HeartBeatState(initialRemoteHeartBeat.getGeneration() + Gossiper.MAX_GENERATION_DIFFERENCE 
+ 1); + EndpointState proposedRemoteState = new EndpointState(proposedRemoteHeartBeat); + + Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, proposedRemoteState)); + + //The generation should have been updated because it isn't over Gossiper.MAX_GENERATION_DIFFERENCE in the future + HeartBeatState actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); + assertEquals(proposedRemoteHeartBeat.getGeneration(), actualRemoteHeartBeat.getGeneration()); + + //Propose a generation 10 years in the future - this should be rejected. + HeartBeatState badProposedRemoteHeartBeat = new HeartBeatState((int) (System.currentTimeMillis() / 1000) + Gossiper.MAX_GENERATION_DIFFERENCE * 10); + EndpointState badProposedRemoteState = new EndpointState(badProposedRemoteHeartBeat); + + Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, badProposedRemoteState)); + + actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); + + //The generation should not have been updated because it is over Gossiper.MAX_GENERATION_DIFFERENCE in the future + assertEquals(proposedRemoteHeartBeat.getGeneration(), actualRemoteHeartBeat.getGeneration()); + } + finally + { + // clean up the gossip states + Gossiper.instance.endpointStateMap.clear(); + } + } + + int stateChangedNum = 0; + + @Test + public void testDuplicatedStateUpdate() throws Exception + { + VersionedValue.VersionedValueFactory valueFactory = + new VersionedValue.VersionedValueFactory(DatabaseDescriptor.getPartitioner()); + + Util.createInitialRing(ss, partitioner, endpointTokens, keyTokens, hosts, hostIds, 2); + try + { + InetAddressAndPort remoteHostAddress = hosts.get(1); + + EndpointState initialRemoteState = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress); + HeartBeatState initialRemoteHeartBeat = initialRemoteState.getHeartBeatState(); + + //Util.createInitialRing should have 
initialized remoteHost's HeartBeatState's generation to 1 + assertEquals(initialRemoteHeartBeat.getGeneration(), 1); + + HeartBeatState proposedRemoteHeartBeat = new HeartBeatState(initialRemoteHeartBeat.getGeneration()); + EndpointState proposedRemoteState = new EndpointState(proposedRemoteHeartBeat); - EndpointState initialRemoteState = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress); - HeartBeatState initialRemoteHeartBeat = initialRemoteState.getHeartBeatState(); + final Token token = DatabaseDescriptor.getPartitioner().getRandomToken(); + VersionedValue tokensValue = valueFactory.tokens(Collections.singletonList(token)); + proposedRemoteState.addApplicationState(ApplicationState.TOKENS, tokensValue); - //Util.createInitialRing should have initialized remoteHost's HeartBeatState's generation to 1 - assertEquals(initialRemoteHeartBeat.getGeneration(), 1); + Gossiper.instance.register( + new IEndpointStateChangeSubscriber() + { + public void onJoin(InetAddressAndPort endpoint, EndpointState epState) { } - HeartBeatState proposedRemoteHeartBeat = new HeartBeatState(initialRemoteHeartBeat.getGeneration() + Gossiper.MAX_GENERATION_DIFFERENCE + 1); - EndpointState proposedRemoteState = new EndpointState(proposedRemoteHeartBeat); + public void beforeChange(InetAddressAndPort endpoint, EndpointState currentState, ApplicationState newStateKey, VersionedValue newValue) { } - Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, proposedRemoteState)); + public void onChange(InetAddressAndPort endpoint, ApplicationState state, VersionedValue value) + { + assertEquals(ApplicationState.TOKENS, state); + stateChangedNum++; + } - //The generation should have been updated because it isn't over Gossiper.MAX_GENERATION_DIFFERENCE in the future - HeartBeatState actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); - assertEquals(proposedRemoteHeartBeat.getGeneration(), 
actualRemoteHeartBeat.getGeneration()); + public void onAlive(InetAddressAndPort endpoint, EndpointState state) { } - //Propose a generation 10 years in the future - this should be rejected. - HeartBeatState badProposedRemoteHeartBeat = new HeartBeatState((int) (System.currentTimeMillis()/1000) + Gossiper.MAX_GENERATION_DIFFERENCE * 10); - EndpointState badProposedRemoteState = new EndpointState(badProposedRemoteHeartBeat); + public void onDead(InetAddressAndPort endpoint, EndpointState state) { } - Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, badProposedRemoteState)); + public void onRemove(InetAddressAndPort endpoint) { } - actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); + public void onRestart(InetAddressAndPort endpoint, EndpointState state) { } + } + ); - //The generation should not have been updated because it is over Gossiper.MAX_GENERATION_DIFFERENCE in the future - assertEquals(proposedRemoteHeartBeat.getGeneration(), actualRemoteHeartBeat.getGeneration()); + stateChangedNum = 0; + Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, proposedRemoteState)); + assertEquals(1, stateChangedNum); + + HeartBeatState actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); + assertEquals(proposedRemoteHeartBeat.getGeneration(), actualRemoteHeartBeat.getGeneration()); + + // Clone a new HeartBeatState + proposedRemoteHeartBeat = new HeartBeatState(initialRemoteHeartBeat.getGeneration(), proposedRemoteHeartBeat.getHeartBeatVersion()); + proposedRemoteState = new EndpointState(proposedRemoteHeartBeat); + + // Bump the heartbeat version and use the same TOKENS state + proposedRemoteHeartBeat.updateHeartBeat(); + proposedRemoteState.addApplicationState(ApplicationState.TOKENS, tokensValue); + + // The following state change should only update heartbeat without updating the TOKENS state + 
Gossiper.instance.applyStateLocally(ImmutableMap.of(remoteHostAddress, proposedRemoteState)); + assertEquals(1, stateChangedNum); + + actualRemoteHeartBeat = Gossiper.instance.getEndpointStateForEndpoint(remoteHostAddress).getHeartBeatState(); + assertEquals(proposedRemoteHeartBeat.getGeneration(), actualRemoteHeartBeat.getGeneration()); + } + finally + { + // clean up the gossip states + Gossiper.instance.endpointStateMap.clear(); + } } // Note: This test might fail if for some reason the node broadcast address is in 127.99.0.0/16 diff --git a/test/unit/org/apache/cassandra/gms/ShadowRoundTest.java b/test/unit/org/apache/cassandra/gms/ShadowRoundTest.java index f8cc49cd645e..57cd4a9f37b8 100644 --- a/test/unit/org/apache/cassandra/gms/ShadowRoundTest.java +++ b/test/unit/org/apache/cassandra/gms/ShadowRoundTest.java @@ -33,10 +33,10 @@ import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.locator.IEndpointSnitch; import org.apache.cassandra.locator.PropertyFileSnitch; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MockMessagingService; import org.apache.cassandra.net.MockMessagingSpy; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.service.StorageService; import static org.apache.cassandra.net.MockMessagingService.verb; @@ -71,11 +71,11 @@ public void testDelayedResponse() int noOfSeeds = Gossiper.instance.seeds.size(); final AtomicBoolean ackSend = new AtomicBoolean(false); - MockMessagingSpy spySyn = MockMessagingService.when(verb(MessagingService.Verb.GOSSIP_DIGEST_SYN)) + MockMessagingSpy spySyn = MockMessagingService.when(verb(Verb.GOSSIP_DIGEST_SYN)) .respondN((msgOut, to) -> { // ACK once to finish shadow round, then busy-spin until gossiper has been enabled - // and then reply with remaining ACKs from other seeds + // and then respond with remaining ACKs from other seeds if 
(!ackSend.compareAndSet(false, true)) { while (!Gossiper.instance.isEnabled()) ; @@ -87,15 +87,17 @@ public void testDelayedResponse() Collections.singletonList(new GossipDigest(to, hb.getGeneration(), hb.getHeartBeatVersion())), Collections.singletonMap(to, state)); - logger.debug("Simulating digest ACK reply"); - return MessageIn.create(to, payload, Collections.emptyMap(), MessagingService.Verb.GOSSIP_DIGEST_ACK, MessagingService.current_version); + logger.debug("Simulating digest ACK response"); + return Message.builder(Verb.GOSSIP_DIGEST_ACK, payload) + .from(to) + .build(); }, noOfSeeds); // GossipDigestAckVerbHandler will send ack2 for each ack received (after the shadow round) - MockMessagingSpy spyAck2 = MockMessagingService.when(verb(MessagingService.Verb.GOSSIP_DIGEST_ACK2)).dontReply(); + MockMessagingSpy spyAck2 = MockMessagingService.when(verb(Verb.GOSSIP_DIGEST_ACK2)).dontReply(); // Migration request messages should not be emitted during shadow round - MockMessagingSpy spyMigrationReq = MockMessagingService.when(verb(MessagingService.Verb.MIGRATION_REQUEST)).dontReply(); + MockMessagingSpy spyMigrationReq = MockMessagingService.when(verb(Verb.SCHEMA_PULL_REQ)).dontReply(); try { @@ -109,7 +111,7 @@ public void testDelayedResponse() // we expect one SYN for each seed during shadow round + additional SYNs after gossiper has been enabled assertTrue(spySyn.messagesIntercepted > noOfSeeds); - // we don't expect to emit any GOSSIP_DIGEST_ACK2 or MIGRATION_REQUEST messages + // we don't expect to emit any GOSSIP_DIGEST_ACK2 or SCHEMA_PULL messages assertEquals(0, spyAck2.messagesIntercepted); assertEquals(0, spyMigrationReq.messagesIntercepted); } diff --git a/test/unit/org/apache/cassandra/hints/HintTest.java b/test/unit/org/apache/cassandra/hints/HintTest.java index aac975b0c220..e3e26d0d7c24 100644 --- a/test/unit/org/apache/cassandra/hints/HintTest.java +++ b/test/unit/org/apache/cassandra/hints/HintTest.java @@ -41,13 +41,12 @@ import 
org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.TokenMetadata; import org.apache.cassandra.metrics.StorageMetrics; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableMetadata; -import org.apache.cassandra.schema.TableParams; import org.apache.cassandra.schema.MigrationManager; import org.apache.cassandra.service.StorageProxy; import org.apache.cassandra.service.StorageService; @@ -58,6 +57,7 @@ import static org.apache.cassandra.Util.dk; import static org.apache.cassandra.hints.HintsTestUtil.assertHintsEqual; import static org.apache.cassandra.hints.HintsTestUtil.assertPartitionsEqual; +import static org.apache.cassandra.net.Verb.HINT_REQ; public class HintTest { @@ -246,9 +246,7 @@ public void testChangedTopology() throws Exception long totalHintCount = StorageProxy.instance.getTotalHints(); // Process hint message. HintMessage message = new HintMessage(localId, hint); - MessagingService.instance().getVerbHandler(MessagingService.Verb.HINT).doVerb( - MessageIn.create(local, message, Collections.emptyMap(), MessagingService.Verb.HINT, MessagingService.current_version), - -1); + HINT_REQ.handler().doVerb(Message.out(HINT_REQ, message)); // hint should not be applied as we no longer are a replica assertNoPartitions(key, TABLE0); @@ -291,9 +289,8 @@ public void testChangedTopologyNotHintable() throws Exception long totalHintCount = StorageMetrics.totalHints.getCount(); // Process hint message. 
HintMessage message = new HintMessage(localId, hint); - MessagingService.instance().getVerbHandler(MessagingService.Verb.HINT).doVerb( - MessageIn.create(local, message, Collections.emptyMap(), MessagingService.Verb.HINT, MessagingService.current_version), - -1); + HINT_REQ.handler().doVerb( + Message.builder(HINT_REQ, message).from(local).build()); // hint should not be applied as we no longer are a replica assertNoPartitions(key, TABLE0); diff --git a/test/unit/org/apache/cassandra/hints/HintsServiceTest.java b/test/unit/org/apache/cassandra/hints/HintsServiceTest.java index b71140f09380..77783311c794 100644 --- a/test/unit/org/apache/cassandra/hints/HintsServiceTest.java +++ b/test/unit/org/apache/cassandra/hints/HintsServiceTest.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.hints; -import java.util.Collections; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -34,6 +33,7 @@ import com.datastax.driver.core.utils.MoreFutures; import org.apache.cassandra.SchemaLoader; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.NoPayload; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.db.DecoratedKey; @@ -41,15 +41,16 @@ import org.apache.cassandra.gms.IFailureDetectionEventListener; import org.apache.cassandra.gms.IFailureDetector; import org.apache.cassandra.metrics.StorageMetrics; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.net.MockMessagingService; import org.apache.cassandra.net.MockMessagingSpy; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.service.StorageService; -import org.apache.cassandra.utils.FBUtilities; import static org.apache.cassandra.Util.dk; +import static org.apache.cassandra.net.Verb.HINT_REQ; +import static 
org.apache.cassandra.net.Verb.HINT_RSP; import static org.apache.cassandra.net.MockMessagingService.verb; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -80,7 +81,8 @@ public void cleanup() @Before public void reinstanciateService() throws Throwable { - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); if (!HintsService.instance.isShutDown()) { @@ -182,20 +184,16 @@ public void testPageSeek() throws InterruptedException, ExecutionException private MockMessagingSpy sendHintsAndResponses(int noOfHints, int noOfResponses) { // create spy for hint messages, but only create responses for noOfResponses hints - MessageIn messageIn = MessageIn.create(FBUtilities.getBroadcastAddressAndPort(), - HintResponse.instance, - Collections.emptyMap(), - MessagingService.Verb.REQUEST_RESPONSE, - MessagingService.current_version); + Message message = Message.internalResponse(HINT_RSP, NoPayload.noPayload); MockMessagingSpy spy; if (noOfResponses != -1) { - spy = MockMessagingService.when(verb(MessagingService.Verb.HINT)).respondN(messageIn, noOfResponses); + spy = MockMessagingService.when(verb(HINT_REQ)).respondN(message, noOfResponses); } else { - spy = MockMessagingService.when(verb(MessagingService.Verb.HINT)).respond(messageIn); + spy = MockMessagingService.when(verb(HINT_REQ)).respond(message); } // create and write noOfHints using service diff --git a/test/unit/org/apache/cassandra/index/CustomIndexTest.java b/test/unit/org/apache/cassandra/index/CustomIndexTest.java index 35e0353f82d3..dfab35df400b 100644 --- a/test/unit/org/apache/cassandra/index/CustomIndexTest.java +++ b/test/unit/org/apache/cassandra/index/CustomIndexTest.java @@ -639,7 +639,7 @@ public void testFailing2iFlush() throws Throwable try { - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); 
fail("Exception should have been propagated"); } catch (Throwable t) @@ -661,7 +661,7 @@ public void indexBuildingPagesLargePartitions() throws Throwable // Insert a single wide partition to be indexed for (int i = 0; i < totalRows; i++) execute("INSERT INTO %s (k, c, v) VALUES (0, ?, ?)", i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Create the index, which won't automatically start building String indexName = "build_single_partition_idx"; @@ -714,7 +714,7 @@ public void partitionIndexTest() throws Throwable execute("INSERT INTO %s (k, c, v) VALUES (?, ?, ?)", 5, 3, 3); execute("DELETE FROM %s WHERE k = ?", 5); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); String indexName = "partition_index_test_idx"; createIndex(String.format("CREATE CUSTOM INDEX %s ON %%s(v) USING '%s'", @@ -776,7 +776,7 @@ public void partitionIsNotOverIndexed() throws Throwable // Insert a single row partition to be indexed for (int i = 0; i < totalRows; i++) execute("INSERT INTO %s (k, c, v) VALUES (0, ?, ?)", i, i); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Create the index, which won't automatically start building String indexName = "partition_overindex_test_idx"; @@ -802,7 +802,7 @@ public void rangeTombstoneTest() throws Throwable // Insert a single range tombstone execute("DELETE FROM %s WHERE k=1 and c > 2"); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // Create the index, which won't automatically start building String indexName = "range_tombstone_idx"; diff --git a/test/unit/org/apache/cassandra/index/internal/CustomCassandraIndex.java b/test/unit/org/apache/cassandra/index/internal/CustomCassandraIndex.java index 1c07760d966f..3e43f11ef46a 100644 --- a/test/unit/org/apache/cassandra/index/internal/CustomCassandraIndex.java +++ b/test/unit/org/apache/cassandra/index/internal/CustomCassandraIndex.java @@ -59,7 +59,6 @@ import org.apache.cassandra.schema.IndexMetadata; import 
org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.Pair; -import org.apache.cassandra.utils.concurrent.OpOrder; import org.apache.cassandra.utils.concurrent.Refs; import static org.apache.cassandra.index.internal.CassandraIndex.getFunctions; @@ -136,7 +135,7 @@ public Optional getBackingTable() public Callable getBlockingFlushTask() { return () -> { - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); return null; }; } @@ -207,7 +206,7 @@ private boolean supportsExpression(RowFilter.Expression expression) public long getEstimatedResultRows() { - return indexCfs.getMeanColumns(); + return indexCfs.getMeanEstimatedCellPerPartitionCount(); } /** @@ -598,7 +597,7 @@ private void invalidate() CompactionManager.instance.interruptCompactionForCFs(cfss, (sstable) -> true, true); CompactionManager.instance.waitForCessation(cfss, (sstable) -> true); indexCfs.keyspace.writeOrder.awaitNewBarrier(); - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); indexCfs.readOrdering.awaitNewBarrier(); indexCfs.invalidate(); } @@ -623,7 +622,7 @@ private Callable getBuildIndexTask() private void buildBlocking() { - baseCfs.forceBlockingFlush(); + baseCfs.forceBlockingFlushToSSTable(); try (ColumnFamilyStore.RefViewFragment viewFragment = baseCfs.selectAndReference(View.selectFunction(SSTableSet.CANONICAL)); Refs sstables = viewFragment.refs) @@ -647,7 +646,7 @@ private void buildBlocking() ImmutableSet.copyOf(sstables)); Future future = CompactionManager.instance.submitIndexBuild(builder); FBUtilities.waitOnFuture(future); - indexCfs.forceBlockingFlush(); + indexCfs.forceBlockingFlushToSSTable(); } logger.info("Index build of {} complete", metadata.name); } diff --git a/test/unit/org/apache/cassandra/index/sasi/SASIIndexTest.java b/test/unit/org/apache/cassandra/index/sasi/SASIIndexTest.java index 90a59ddca6b1..ed249f624e7b 100644 --- a/test/unit/org/apache/cassandra/index/sasi/SASIIndexTest.java +++ 
b/test/unit/org/apache/cassandra/index/sasi/SASIIndexTest.java @@ -88,6 +88,8 @@ import org.junit.*; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public class SASIIndexTest { private static final IPartitioner PARTITIONER; @@ -451,7 +453,7 @@ private void testPrefixSearchWithContainsMode(boolean forceFlush) throws Excepti if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); final UntypedResultSet results = executeCQL(FTS_CF_NAME, "SELECT * FROM %s.%s WHERE artist LIKE 'lady%%'"); Assert.assertNotNull(results); @@ -805,7 +807,7 @@ private void testColumnNamesWithSlashes(boolean forceFlush) throws Exception rm3.build().apply(); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); final ByteBuffer dataOutputId = UTF8Type.instance.decompose("/data/output/id"); @@ -967,7 +969,7 @@ private void redistributeSummaries(int expected, ColumnFamilyStore store, ByteBu { setMinIndexInterval(minIndexInterval); IndexSummaryManager.instance.redistributeSummaries(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Set rows = getIndexed(store, 100, buildExpression(firstName, Operator.LIKE_CONTAINS, UTF8Type.instance.decompose("a"))); Assert.assertEquals(rows.toString(), expected, rows.size()); @@ -1187,7 +1189,7 @@ public void testInsertingIncorrectValuesIntoAgeIndex() }}); rm.build().apply(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Set rows = getIndexed(store, 10, buildExpression(firstName, Operator.EQ, UTF8Type.instance.decompose("a")), buildExpression(age, Operator.GTE, Int32Type.instance.decompose(26))); @@ -1232,7 +1234,7 @@ private void testUnicodeSupport(boolean forceFlush) rm.build().apply(); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Set rows; @@ -1304,7 +1306,7 @@ private void testUnicodeSuffixModeNoSplits(boolean forceFlush) rm.build().apply(); if (forceFlush) - store.forceBlockingFlush(); + 
store.forceBlockingFlushToSSTable(); Set rows; @@ -1365,7 +1367,7 @@ public void testThatTooBigValueIsRejected() rows = getIndexed(store, 10, buildExpression(comment, Operator.LIKE_MATCHES, bigValue.duplicate())); Assert.assertEquals(0, rows.size()); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); rows = getIndexed(store, 10, buildExpression(comment, Operator.LIKE_MATCHES, bigValue.duplicate())); Assert.assertEquals(0, rows.size()); @@ -1416,7 +1418,7 @@ public void testSearchTimeouts() throws Exception try (ReadExecutionController controller = command.executionController()) { - Set rows = getKeys(new QueryPlan(store, command, DatabaseDescriptor.getRangeRpcTimeout()).execute(controller)); + Set rows = getKeys(new QueryPlan(store, command, DatabaseDescriptor.getRangeRpcTimeout(MILLISECONDS)).execute(controller)); Assert.assertTrue(rows.toString(), Arrays.equals(new String[] { "key1", "key2", "key3", "key4" }, rows.toArray(new String[rows.size()]))); } } @@ -1468,7 +1470,7 @@ public void testChinesePrefixSearch() update(rm, fullName, UTF8Type.instance.decompose("利久 寺地"), System.currentTimeMillis()); rm.build().apply(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Set rows; @@ -1505,7 +1507,7 @@ public void testLowerCaseAnalyzer(boolean forceFlush) rm.build().apply(); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); Set rows; @@ -1590,7 +1592,7 @@ public void testPrefixSSTableLookup() rm.build().apply(); // first flush would make interval for name - 'johnny' -> 'pavel' - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); rm = new Mutation.PartitionUpdateCollector(KS_NAME, decoratedKey("key6")); update(rm, name, UTF8Type.instance.decompose("Jason"), System.currentTimeMillis()); @@ -1605,7 +1607,7 @@ public void testPrefixSSTableLookup() rm.build().apply(); // this flush is going to produce range - 'jason' -> 'vijay' - store.forceBlockingFlush(); + 
store.forceBlockingFlushToSSTable(); // make sure that overlap of the prefixes is properly handled across sstables // since simple interval tree lookup is not going to cover it, prefix lookup actually required. @@ -1769,7 +1771,7 @@ public void testClusteringIndexes(boolean forceFlush) throws Exception executeCQL(CLUSTERING_CF_NAME_1 ,"INSERT INTO %s.%s (name, nickname, location, age, height, score) VALUES (?, ?, ?, ?, ?, ?)", "Jordan", "jrwest", "US", 27, 182, 1.0); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); UntypedResultSet results; @@ -1856,7 +1858,7 @@ public void testClusteringIndexes(boolean forceFlush) throws Exception executeCQL(CLUSTERING_CF_NAME_2 ,"INSERT INTO %s.%s (name, nickname, location, age, height, score) VALUES (?, ?, ?, ?, ?, ?)", "Christopher", "chis", "US", 27, 180, 1.0); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); results = executeCQL(CLUSTERING_CF_NAME_2 ,"SELECT * FROM %s.%s WHERE location LIKE 'US' AND age = 43 ALLOW FILTERING"); Assert.assertNotNull(results); @@ -1882,7 +1884,7 @@ public void testStaticIndex(boolean shouldFlush) throws Exception executeCQL(STATIC_CF_NAME, "INSERT INTO %s.%s (sensor_id,date,value,variance) VALUES(?, ?, ?, ?)", 1, 20160403L, 24.96, 4); if (shouldFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); executeCQL(STATIC_CF_NAME, "INSERT INTO %s.%s (sensor_id,sensor_type) VALUES(?, ?)", 2, "PRESSURE"); executeCQL(STATIC_CF_NAME, "INSERT INTO %s.%s (sensor_id,date,value,variance) VALUES(?, ?, ?, ?)", 2, 20160401L, 1.03, 9); @@ -1890,7 +1892,7 @@ public void testStaticIndex(boolean shouldFlush) throws Exception executeCQL(STATIC_CF_NAME, "INSERT INTO %s.%s (sensor_id,date,value,variance) VALUES(?, ?, ?, ?)", 2, 20160403L, 1.01, 4); if (shouldFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); UntypedResultSet results; @@ -1971,7 +1973,7 @@ public void testTableRebuild() throws Exception 
executeCQL(CLUSTERING_CF_NAME_1, "INSERT INTO %s.%s (name, location, age, height, score) VALUES (?, ?, ?, ?, ?)", "Pavel", "BY", 28, 182, 2.0); executeCQL(CLUSTERING_CF_NAME_1, "INSERT INTO %s.%s (name, nickname, location, age, height, score) VALUES (?, ?, ?, ?, ?, ?)", "Jordan", "jrwest", "US", 27, 182, 1.0); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); SSTable ssTable = store.getSSTables(SSTableSet.LIVE).iterator().next(); Path path = FileSystems.getDefault().getPath(ssTable.getFilename().replace("-Data", "-SI_" + CLUSTERING_CF_NAME_1 + "_age")); @@ -2008,7 +2010,7 @@ public void testIndexRebuild() throws Exception executeCQL(CLUSTERING_CF_NAME_1, "INSERT INTO %s.%s (name, nickname) VALUES (?, ?)", "Alex", "ifesdjeen"); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); for (Index index : store.indexManager.listIndexes()) { @@ -2134,7 +2136,7 @@ private void testLIKEAndEQSemanticsWithDifferenceKindsOfIndexes(String containsT { Keyspace keyspace = Keyspace.open(KS_NAME); for (String table : Arrays.asList(containsTable, prefixTable, analyzedPrefixTable)) - keyspace.getColumnFamilyStore(table).forceBlockingFlush(); + keyspace.getColumnFamilyStore(table).forceBlockingFlushToSSTable(); } UntypedResultSet results; @@ -2368,7 +2370,7 @@ public void testIndexMemtableSwitching() Assert.assertTrue(beforeFlushMemtable.search(expression).getCount() > 0); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); IndexMemtable afterFlushMemtable = index.getCurrentMemtable(); @@ -2499,7 +2501,7 @@ private static ColumnFamilyStore loadData(Map> dat ColumnFamilyStore store = Keyspace.open(KS_NAME).getColumnFamilyStore(CF_NAME); if (forceFlush) - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); return store; } diff --git a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterReopenTest.java b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterReopenTest.java index 
461c13cd09ee..f02820b7badc 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterReopenTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterReopenTest.java @@ -78,12 +78,12 @@ public void compressionEnabled() throws Throwable { execute("insert into %s (id, t) values (?, ?)", i, ByteBuffer.wrap(blob)); } - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); for (int i = 0; i < 10000; i++) { execute("insert into %s (id, t) values (?, ?)", i, ByteBuffer.wrap(blob)); } - getCurrentColumnFamilyStore().forceBlockingFlush(); + getCurrentColumnFamilyStore().forceBlockingFlushToSSTable(); DatabaseDescriptor.setSSTablePreemptiveOpenIntervalInMB(1); getCurrentColumnFamilyStore().forceMajorCompaction(); } diff --git a/test/unit/org/apache/cassandra/io/compress/CompressorTest.java b/test/unit/org/apache/cassandra/io/compress/CompressorTest.java index b649f526d9e4..d45d94165785 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressorTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressorTest.java @@ -82,7 +82,7 @@ public void testArrayUncompress(byte[] data, int off, int len) throws IOExceptio // need byte[] representation which direct buffers don't have byte[] compressedBytes = new byte[compressed.capacity()]; - ByteBufferUtil.arrayCopy(compressed, outOffset, compressedBytes, outOffset, compressed.limit() - outOffset); + ByteBufferUtil.copyBytes(compressed, outOffset, compressedBytes, outOffset, compressed.limit() - outOffset); final int decompressedLength = compressor.uncompress(compressedBytes, outOffset, compressed.remaining(), restored, restoreOffset); diff --git a/test/unit/org/apache/cassandra/io/sstable/IndexSummaryManagerTest.java b/test/unit/org/apache/cassandra/io/sstable/IndexSummaryManagerTest.java index 68ee3e1931bc..eb98643cadcb 100644 --- a/test/unit/org/apache/cassandra/io/sstable/IndexSummaryManagerTest.java +++ 
b/test/unit/org/apache/cassandra/io/sstable/IndexSummaryManagerTest.java @@ -23,7 +23,6 @@ import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.stream.Collectors; import com.google.common.base.Joiner; import com.google.common.collect.Sets; @@ -42,7 +41,6 @@ import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.RowUpdateBuilder; -import org.apache.cassandra.db.compaction.AntiCompactionTest; import org.apache.cassandra.db.compaction.CompactionInfo; import org.apache.cassandra.db.compaction.CompactionInterruptedException; import org.apache.cassandra.db.compaction.CompactionManager; @@ -51,7 +49,6 @@ import org.apache.cassandra.db.lifecycle.LifecycleTransaction; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.io.sstable.format.SSTableReader; -import org.apache.cassandra.metrics.CompactionMetrics; import org.apache.cassandra.metrics.RestorableMeter; import org.apache.cassandra.schema.CachingParams; import org.apache.cassandra.schema.KeyspaceParams; @@ -194,7 +191,7 @@ private void createSSTables(String ksname, String cfname, int numSSTables, int n .build() .applyUnsafe(); } - futures.add(cfs.forceFlush()); + futures.add(cfs.forceFlushToSSTable()); } for (Future future : futures) { @@ -519,7 +516,7 @@ public void testRebuildAtSamplingLevel() throws IOException .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List sstables = new ArrayList<>(cfs.getLiveSSTables()); assertEquals(1, sstables.size()); @@ -584,7 +581,7 @@ public void testJMXFunctions() throws IOException .build() .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } assertTrue(manager.getAverageIndexInterval() >= cfs.metadata().params.minIndexInterval); diff --git a/test/unit/org/apache/cassandra/io/sstable/IndexSummaryRedistributionTest.java 
b/test/unit/org/apache/cassandra/io/sstable/IndexSummaryRedistributionTest.java index 07a2212e8f9d..919e873a9009 100644 --- a/test/unit/org/apache/cassandra/io/sstable/IndexSummaryRedistributionTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/IndexSummaryRedistributionTest.java @@ -125,7 +125,7 @@ private void createSSTables(String ksname, String cfname, int numSSTables, int n .build() .applyUnsafe(); } - futures.add(cfs.forceFlush()); + futures.add(cfs.forceFlushToSSTable()); } for (Future future : futures) { diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableCorruptionDetectionTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableCorruptionDetectionTest.java index 2510c5e8a487..a44f6921056c 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableCorruptionDetectionTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/SSTableCorruptionDetectionTest.java @@ -117,7 +117,7 @@ public static void setUp() .add("reg2", ByteBuffer.wrap(reg2)); writer.append(builder.build().unfilteredIterator()); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ssTableReader = writer.finish(true); txn.update(ssTableReader, false); diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableLoaderTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableLoaderTest.java index 5d40f8cc6218..13cdfd38fd23 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableLoaderTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/SSTableLoaderTest.java @@ -138,7 +138,7 @@ public void testLoadingSSTable() throws Exception } ColumnFamilyStore cfs = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD1); - cfs.forceBlockingFlush(); // wait for sstables to be on disk else we won't be able to stream them + cfs.forceBlockingFlushToSSTable(); // wait for sstables to be on disk else we won't be able to stream them final CountDownLatch latch = new CountDownLatch(1); SSTableLoader loader = new SSTableLoader(dataDir, new TestClient(), new 
OutputHandler.SystemOutput(false, false)); @@ -185,7 +185,7 @@ public void testLoadingIncompleteSSTable() throws Exception } ColumnFamilyStore cfs = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD2); - cfs.forceBlockingFlush(); // wait for sstables to be on disk else we won't be able to stream them + cfs.forceBlockingFlushToSSTable(); // wait for sstables to be on disk else we won't be able to stream them //make sure we have some tables... assertTrue(dataDir.listFiles().length > 0); @@ -233,14 +233,14 @@ public void testLoadingSSTableToDifferentKeyspace() throws Exception } ColumnFamilyStore cfs = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD1); - cfs.forceBlockingFlush(); // wait for sstables to be on disk else we won't be able to stream them + cfs.forceBlockingFlushToSSTable(); // wait for sstables to be on disk else we won't be able to stream them final CountDownLatch latch = new CountDownLatch(1); SSTableLoader loader = new SSTableLoader(dataDir, new TestClient(), new OutputHandler.SystemOutput(false, false), 1, KEYSPACE2); loader.stream(Collections.emptySet(), completionStreamListener(latch)).get(); cfs = Keyspace.open(KEYSPACE2).getColumnFamilyStore(CF_STANDARD1); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); List partitions = Util.getAll(Util.cmd(cfs).build()); diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableMetadataTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableMetadataTest.java index 98b356a0ae75..3f103167f778 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableMetadataTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/SSTableMetadataTest.java @@ -93,7 +93,7 @@ public void testTrackMaxDeletionTime() .applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); int ttltimestamp = (int)(System.currentTimeMillis()/1000); int firstDelTime = 0; @@ -112,7 +112,7 @@ public void testTrackMaxDeletionTime() 
ttltimestamp = (int) (System.currentTimeMillis()/1000); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(2, store.getLiveSSTables().size()); List sstables = new ArrayList<>(store.getLiveSSTables()); if(sstables.get(0).getSSTableMetadata().maxLocalDeletionTime < sstables.get(1).getSSTableMetadata().maxLocalDeletionTime) @@ -166,7 +166,7 @@ public void testWithDeletes() throws ExecutionException, InterruptedException .build() .applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1,store.getLiveSSTables().size()); int ttltimestamp = (int) (System.currentTimeMillis()/1000); int firstMaxDelTime = 0; @@ -178,7 +178,7 @@ public void testWithDeletes() throws ExecutionException, InterruptedException RowUpdateBuilder.deleteRow(store.metadata(), timestamp + 1, "deletetest", "todelete").applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(2,store.getLiveSSTables().size()); boolean foundDelete = false; for(SSTableReader sstable : store.getLiveSSTables()) @@ -215,7 +215,7 @@ public void trackMaxMinColNames() throws CharacterCodingException, ExecutionExce .applyUnsafe(); } } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); for (SSTableReader sstable : store.getLiveSSTables()) { @@ -233,7 +233,7 @@ public void trackMaxMinColNames() throws CharacterCodingException, ExecutionExce .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); store.forceMajorCompaction(); assertEquals(1, store.getLiveSSTables().size()); for (SSTableReader sstable : store.getLiveSSTables()) diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableReaderTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableReaderTest.java index 580b099bda98..16b78ba12807 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableReaderTest.java +++ 
b/test/unit/org/apache/cassandra/io/sstable/SSTableReaderTest.java @@ -22,7 +22,6 @@ import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.attribute.FileAttribute; import java.util.*; import java.util.concurrent.*; @@ -111,7 +110,7 @@ public void testGetPositionsForRanges() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); List> ranges = new ArrayList>(); @@ -157,7 +156,7 @@ public void testSpannedIndexPositions() throws IOException .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); // check that all our keys are found correctly @@ -199,7 +198,7 @@ public void testPersistentStatistics() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); clearAndLoad(store); assert store.metric.maxPartitionSize.getValue() != 0; @@ -228,7 +227,7 @@ public void testReadRateTracking() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); SSTableReader sstable = store.getLiveSSTables().iterator().next(); assertEquals(0, sstable.getReadMeter().count()); @@ -261,7 +260,7 @@ public void testGetPositionsForRangesWithKeyCache() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); SSTableReader sstable = store.getLiveSSTables().iterator().next(); @@ -293,7 +292,7 @@ public void testPersistentStatisticsWithSecondaryIndex() .build() .applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); // check if opening and querying works assertIndexQueryWorks(store); @@ -315,7 +314,7 @@ public void testGetPositionsKeyCacheStats() .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); 
CompactionManager.instance.performMaximal(store, false); SSTableReader sstable = store.getLiveSSTables().iterator().next(); @@ -363,7 +362,7 @@ public void testOpeningSSTable() throws Exception .build() .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); SSTableReader sstable = store.getLiveSSTables().iterator().next(); Descriptor desc = sstable.descriptor; @@ -465,7 +464,7 @@ public void testLoadingSummaryUsesCorrectPartitioner() throws Exception .build() .applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); for(ColumnFamilyStore indexCfs : store.indexManager.getAllIndexColumnFamilyStores()) { @@ -494,7 +493,7 @@ public void testGetScannerForNoIntersectingRanges() throws Exception .build() .applyUnsafe(); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); boolean foundScanner = false; for (SSTableReader s : store.getLiveSSTables()) { @@ -528,7 +527,7 @@ public void testGetPositionsForRangesFromTableOpenedForBulkLoading() throws IOEx .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); // construct a range which is present in the sstable, but whose @@ -567,7 +566,7 @@ public void testIndexSummaryReplacement() throws IOException, ExecutionException .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); Collection sstables = store.getLiveSSTables(); @@ -646,7 +645,7 @@ private void testIndexSummaryUpsampleAndReload0() throws Exception .applyUnsafe(); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); CompactionManager.instance.performMaximal(store, false); Collection sstables = store.getLiveSSTables(); @@ -759,7 +758,7 @@ private SSTableReader getNewSSTable(ColumnFamilyStore cfs) .build() .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); return 
Sets.difference(cfs.getLiveSSTables(), before).iterator().next(); } diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableRewriterTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableRewriterTest.java index 7c47c8b9beb0..5e562616beab 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableRewriterTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/SSTableRewriterTest.java @@ -22,19 +22,15 @@ import java.io.IOException; import java.util.*; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; -import com.google.common.util.concurrent.Uninterruptibles; import org.junit.Test; import org.apache.cassandra.Util; import org.apache.cassandra.UpdateBuilder; import org.apache.cassandra.concurrent.NamedThreadFactory; -import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.Keyspace; @@ -56,8 +52,6 @@ import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.db.lifecycle.LifecycleTransaction; import org.apache.cassandra.metrics.StorageMetrics; -import org.apache.cassandra.streaming.PreviewKind; -import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.UUIDGen; @@ -81,7 +75,7 @@ public void basicTest() throws InterruptedException .build() .apply(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); Set sstables = new HashSet<>(cfs.getLiveSSTables()); assertEquals(1, sstables.size()); assertEquals(sstables.iterator().next().bytesOnDisk(), cfs.metric.liveDiskSpaceUsed.getCount()); @@ -698,7 +692,7 @@ public void testAllKeysReadable() throws Exception .build() 
.apply(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); cfs.forceMajorCompaction(); validateKeys(keyspace); diff --git a/test/unit/org/apache/cassandra/io/sstable/SSTableScannerTest.java b/test/unit/org/apache/cassandra/io/sstable/SSTableScannerTest.java index eff95fccbb1c..46df324cc1e0 100644 --- a/test/unit/org/apache/cassandra/io/sstable/SSTableScannerTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/SSTableScannerTest.java @@ -215,7 +215,7 @@ public void testSingleDataRange() throws IOException for (int i = 2; i < 10; i++) insertRowWithKey(store.metadata(), i); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); SSTableReader sstable = store.getLiveSSTables().iterator().next(); @@ -321,7 +321,7 @@ public void testMultipleRanges() throws IOException for (int i = 0; i < 3; i++) for (int j = 2; j < 10; j++) insertRowWithKey(store.metadata(), i * 100 + j); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); SSTableReader sstable = store.getLiveSSTables().iterator().next(); @@ -441,7 +441,7 @@ public void testSingleKeyMultipleRanges() throws IOException store.disableAutoCompaction(); insertRowWithKey(store.metadata(), 205); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertEquals(1, store.getLiveSSTables().size()); SSTableReader sstable = store.getLiveSSTables().iterator().next(); diff --git a/test/unit/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriterTest.java b/test/unit/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriterTest.java index c3931e0ce8c5..1aef85103745 100644 --- a/test/unit/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriterTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/format/big/BigTableZeroCopyWriterTest.java @@ -50,7 +50,7 @@ import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import 
org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.FileHandle; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; import org.apache.cassandra.schema.CachingParams; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.schema.Schema; @@ -116,7 +116,7 @@ public static void defineSchema() throws Exception .applyUnsafe(); expectedRowCount++; } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); sstable = store.getLiveSSTables().iterator().next(); } @@ -130,12 +130,16 @@ public void writeDataFile_DataInputPlus() @Test public void writeDataFile_RebufferingByteBufDataInputPlus() { - writeDataTestCycle(buffer -> { - EmbeddedChannel channel = new EmbeddedChannel(); - RebufferingByteBufDataInputPlus inputPlus = new RebufferingByteBufDataInputPlus(1 << 10, 1 << 20, channel.config()); - inputPlus.append(Unpooled.wrappedBuffer(buffer)); - return inputPlus; - }); + try (AsyncStreamingInputPlus input = new AsyncStreamingInputPlus(new EmbeddedChannel())) + { + writeDataTestCycle(buffer -> + { + input.append(Unpooled.wrappedBuffer(buffer)); + return input; + }); + + input.requestClosure(); + } } diff --git a/test/unit/org/apache/cassandra/io/util/BufferedDataOutputStreamTest.java b/test/unit/org/apache/cassandra/io/util/BufferedDataOutputStreamTest.java index 7ca2273a05f8..7c1a0da18198 100644 --- a/test/unit/org/apache/cassandra/io/util/BufferedDataOutputStreamTest.java +++ b/test/unit/org/apache/cassandra/io/util/BufferedDataOutputStreamTest.java @@ -610,48 +610,4 @@ public void testWriteSlowByteOrder() throws Exception } } - @Test - public void testWriteExcessSlow() throws Exception - { - try (DataOutputBuffer dob = new DataOutputBuffer(4)) - { - dob.strictFlushing = true; - ByteBuffer buf = ByteBuffer.allocateDirect(8); - buf.putLong(0, 42); - dob.write(buf); - assertEquals(42, ByteBuffer.wrap(dob.toByteArray()).getLong()); - } - } 
- - @Test - public void testApplyToChannel() throws Exception - { - setUp(); - Object obj = new Object(); - Object retval = ndosp.applyToChannel( channel -> { - ByteBuffer buf = ByteBuffer.allocate(8); - buf.putLong(0, 42); - try - { - channel.write(buf); - } - catch (Exception e) - { - throw new RuntimeException(e); - } - return obj; - }); - assertEquals(obj, retval); - assertEquals(42, ByteBuffer.wrap(generated.toByteArray()).getLong()); - } - - @Test(expected = UnsupportedOperationException.class) - public void testApplyToChannelThrowsForMisaligned() throws Exception - { - setUp(); - ndosp.strictFlushing = true; - ndosp.applyToChannel( channel -> { - return null; - }); - } } diff --git a/test/unit/org/apache/cassandra/locator/AlibabaCloudSnitchTest.java b/test/unit/org/apache/cassandra/locator/AlibabaCloudSnitchTest.java new file mode 100644 index 000000000000..4e8ab164634c --- /dev/null +++ b/test/unit/org/apache/cassandra/locator/AlibabaCloudSnitchTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.locator; + +import static org.junit.Assert.assertEquals; +import java.io.IOException; +import java.net.InetAddress; +import java.util.EnumMap; +import java.util.Map; +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.gms.ApplicationState; +import org.apache.cassandra.gms.Gossiper; +import org.apache.cassandra.gms.VersionedValue; +import org.apache.cassandra.service.StorageService; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class AlibabaCloudSnitchTest +{ + private static String az; + + @BeforeClass + public static void setup() throws Exception + { + System.setProperty(Gossiper.Props.DISABLE_THREAD_VALIDATION, "true"); + DatabaseDescriptor.daemonInitialization(); + SchemaLoader.mkdirs(); + SchemaLoader.cleanup(); + Keyspace.setInitialized(); + StorageService.instance.initServer(0); + } + + private class TestAlibabaCloudSnitch extends AlibabaCloudSnitch + { + public TestAlibabaCloudSnitch() throws IOException, ConfigurationException + { + super(); + } + + @Override + String alibabaApiCall(String url) throws IOException, ConfigurationException + { + return az; + } + } + + @Test + public void testRac() throws IOException, ConfigurationException + { + az = "cn-hangzhou-f"; + AlibabaCloudSnitch snitch = new TestAlibabaCloudSnitch(); + InetAddressAndPort local = InetAddressAndPort.getByName("127.0.0.1"); + InetAddressAndPort nonlocal = InetAddressAndPort.getByName("127.0.0.7"); + + Gossiper.instance.addSavedEndpoint(nonlocal); + Map stateMap = new EnumMap<>(ApplicationState.class); + stateMap.put(ApplicationState.DC, StorageService.instance.valueFactory.datacenter("cn-shanghai")); + stateMap.put(ApplicationState.RACK, StorageService.instance.valueFactory.datacenter("a")); + 
Gossiper.instance.getEndpointStateForEndpoint(nonlocal).addApplicationStates(stateMap); + + assertEquals("cn-shanghai", snitch.getDatacenter(nonlocal)); + assertEquals("a", snitch.getRack(nonlocal)); + + assertEquals("cn-hangzhou", snitch.getDatacenter(local)); + assertEquals("f", snitch.getRack(local)); + } + + @Test + public void testNewRegions() throws IOException, ConfigurationException + { + az = "us-east-1a"; + AlibabaCloudSnitch snitch = new TestAlibabaCloudSnitch(); + InetAddressAndPort local = InetAddressAndPort.getByName("127.0.0.1"); + assertEquals("us-east", snitch.getDatacenter(local)); + assertEquals("1a", snitch.getRack(local)); + } + + @AfterClass + public static void tearDown() + { + StorageService.instance.stopClient(); + } + +} diff --git a/test/unit/org/apache/cassandra/locator/DynamicEndpointSnitchTest.java b/test/unit/org/apache/cassandra/locator/DynamicEndpointSnitchTest.java index fbf6e89ff454..069c2227e8e9 100644 --- a/test/unit/org/apache/cassandra/locator/DynamicEndpointSnitchTest.java +++ b/test/unit/org/apache/cassandra/locator/DynamicEndpointSnitchTest.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.*; +import java.util.concurrent.TimeUnit; import org.junit.BeforeClass; import org.junit.Test; @@ -30,6 +31,8 @@ import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public class DynamicEndpointSnitchTest { @@ -44,7 +47,7 @@ private static void setScores(DynamicEndpointSnitch dsnitch, int rounds, List= MessagingService.VERSION_40) @@ -65,8 +65,8 @@ private void testAddress(InetAddressAndPort address, int version) throws Excepti } else { - assertEquals(roundtripped.address, address.address); - assertEquals(7000, roundtripped.port); + assertEquals(address.address, roundtripped.address); + assertEquals(InetAddressAndPort.getDefaultPort(), roundtripped.port); } } } diff --git 
a/test/unit/org/apache/cassandra/net/AsyncChannelPromiseTest.java b/test/unit/org/apache/cassandra/net/AsyncChannelPromiseTest.java new file mode 100644 index 000000000000..c4e62950d936 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/AsyncChannelPromiseTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import org.junit.After; +import org.junit.Test; + +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; + +public class AsyncChannelPromiseTest extends TestAbstractAsyncPromise +{ + @After + public void shutdown() + { + exec.shutdownNow(); + } + + private ChannelPromise newPromise() + { + return new AsyncChannelPromise(new EmbeddedChannel()); + } + + @Test + public void testSuccess() + { + for (boolean setUncancellable : new boolean[] { false, true }) + for (boolean tryOrSet : new boolean[]{ false, true }) + testOneSuccess(newPromise(), setUncancellable, tryOrSet, null, null); + } + + @Test + public void testFailure() + { + for (boolean setUncancellable : new boolean[] { false, true }) + for (boolean tryOrSet : new boolean[] { false, true }) + for (Throwable v : new Throwable[] { null, new NullPointerException() }) + testOneFailure(newPromise(), setUncancellable, tryOrSet, v, null); + } + + + @Test + public void testCancellation() + { + for (boolean interruptIfRunning : new boolean[] { true, false }) + testOneCancellation(newPromise(), interruptIfRunning, null); + } + + + @Test + public void testTimeout() + { + for (boolean setUncancellable : new boolean[] { true, false }) + testOneTimeout(newPromise(), setUncancellable); + } + +} diff --git a/test/unit/org/apache/cassandra/net/AsyncMessageOutputPlusTest.java b/test/unit/org/apache/cassandra/net/AsyncMessageOutputPlusTest.java new file mode 100644 index 000000000000..633207c5f82d --- /dev/null +++ b/test/unit/org/apache/cassandra/net/AsyncMessageOutputPlusTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; + +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.FrameEncoder.PayloadAllocator; + +import static org.junit.Assert.assertEquals; + +public class AsyncMessageOutputPlusTest +{ + + static + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test + public void testSuccess() throws IOException + { + EmbeddedChannel channel = new TestChannel(4); + ByteBuf read; + try (AsyncMessageOutputPlus out = new AsyncMessageOutputPlus(channel, 32, Integer.MAX_VALUE, PayloadAllocator.simple)) + { + out.writeInt(1); + assertEquals(0, out.flushed()); + assertEquals(0, out.flushedToNetwork()); + assertEquals(4, out.position()); + + out.doFlush(0); + assertEquals(4, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + out.writeInt(2); + assertEquals(8, out.position()); + assertEquals(4, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + out.doFlush(0); + assertEquals(8, out.position()); + assertEquals(8, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(4, read.readableBytes()); + assertEquals(1, read.getInt(0)); + assertEquals(8, out.flushed()); + assertEquals(8, out.flushedToNetwork()); + + read = 
channel.readOutbound(); + assertEquals(4, read.readableBytes()); + assertEquals(2, read.getInt(0)); + + out.write(new byte[64]); + assertEquals(72, out.position()); + assertEquals(40, out.flushed()); + assertEquals(40, out.flushedToNetwork()); + + out.doFlush(0); + assertEquals(72, out.position()); + assertEquals(72, out.flushed()); + assertEquals(40, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(32, read.readableBytes()); + assertEquals(0, read.getLong(0)); + assertEquals(72, out.position()); + assertEquals(72, out.flushed()); + assertEquals(72, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(32, read.readableBytes()); + assertEquals(0, read.getLong(0)); + } + + } + +} diff --git a/test/unit/org/apache/cassandra/net/AsyncOneResponseTest.java b/test/unit/org/apache/cassandra/net/AsyncOneResponseTest.java index 15b3327be203..3d0508cc07c1 100644 --- a/test/unit/org/apache/cassandra/net/AsyncOneResponseTest.java +++ b/test/unit/org/apache/cassandra/net/AsyncOneResponseTest.java @@ -19,45 +19,35 @@ package org.apache.cassandra.net; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertTrue; public class AsyncOneResponseTest { - @Test(expected = TimeoutException.class) - public void getThrowsExceptionAfterTimeout() throws InterruptedException, TimeoutException + @Test + public void getThrowsExceptionAfterTimeout() throws InterruptedException { AsyncOneResponse response = new AsyncOneResponse<>(); Thread.sleep(2000); - response.get(1, TimeUnit.SECONDS); + Assert.assertFalse(response.await(1, TimeUnit.SECONDS)); } @Test - public void getThrowsExceptionAfterCorrectTimeout() + public void getThrowsExceptionAfterCorrectTimeout() throws InterruptedException { AsyncOneResponse response = new AsyncOneResponse<>(); final long expectedTimeoutMillis = 1000; // Should time out after roughly this time final long 
schedulingError = 10; // Scheduling is imperfect - boolean hitException = false; // Ensure we actually hit the TimeoutException - - long startTime = System.currentTimeMillis(); - - try - { - response.get(expectedTimeoutMillis, TimeUnit.MILLISECONDS); - } - catch(TimeoutException e) - { - hitException = true; - } - long endTime = System.currentTimeMillis(); + long startTime = System.nanoTime(); + boolean timeout = !response.await(expectedTimeoutMillis, TimeUnit.MILLISECONDS); + long endTime = System.nanoTime(); - assertTrue(hitException); - assertTrue(endTime - startTime > (expectedTimeoutMillis - schedulingError)); + assertTrue(timeout); + assertTrue(TimeUnit.NANOSECONDS.toMillis(endTime - startTime) > (expectedTimeoutMillis - schedulingError)); } } diff --git a/test/unit/org/apache/cassandra/net/AsyncPromiseTest.java b/test/unit/org/apache/cassandra/net/AsyncPromiseTest.java new file mode 100644 index 000000000000..0d2a2e96c25e --- /dev/null +++ b/test/unit/org/apache/cassandra/net/AsyncPromiseTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import org.junit.After; +import org.junit.Test; + +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; + +public class AsyncPromiseTest extends TestAbstractAsyncPromise +{ + @After + public void shutdown() + { + exec.shutdownNow(); + } + + private Promise newPromise() + { + return new AsyncPromise<>(ImmediateEventExecutor.INSTANCE); + } + + @Test + public void testSuccess() + { + for (boolean setUncancellable : new boolean[] { false, true }) + for (boolean tryOrSet : new boolean[]{ false, true }) + for (Integer v : new Integer[]{ null, 1 }) + testOneSuccess(newPromise(), setUncancellable, tryOrSet, v, 2); + } + + @Test + public void testFailure() + { + for (boolean setUncancellable : new boolean[] { false, true }) + for (boolean tryOrSet : new boolean[] { false, true }) + for (Throwable v : new Throwable[] { null, new NullPointerException() }) + testOneFailure(newPromise(), setUncancellable, tryOrSet, v, 2); + } + + + @Test + public void testCancellation() + { + for (boolean interruptIfRunning : new boolean[] { true, false }) + testOneCancellation(newPromise(), interruptIfRunning, 2); + } + + + @Test + public void testTimeout() + { + for (boolean setUncancellable : new boolean[] { true, false }) + testOneTimeout(newPromise(), setUncancellable); + } + +} diff --git a/test/unit/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlusTest.java b/test/unit/org/apache/cassandra/net/AsyncStreamingInputPlusTest.java similarity index 60% rename from test/unit/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlusTest.java rename to test/unit/org/apache/cassandra/net/AsyncStreamingInputPlusTest.java index 69df0403a223..b57574741731 100644 --- a/test/unit/org/apache/cassandra/net/async/RebufferingByteBufDataInputPlusTest.java +++ b/test/unit/org/apache/cassandra/net/AsyncStreamingInputPlusTest.java @@ -16,12 +16,13 @@ * limitations under the License. 
*/ -package org.apache.cassandra.net.async; +package org.apache.cassandra.net; import java.io.EOFException; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Assert; @@ -32,65 +33,65 @@ import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import org.apache.cassandra.io.util.BufferedDataOutputStreamPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus; +import org.apache.cassandra.net.AsyncStreamingInputPlus.InputTimeoutException; -public class RebufferingByteBufDataInputPlusTest +import static org.junit.Assert.assertFalse; + +public class AsyncStreamingInputPlusTest { private EmbeddedChannel channel; - private RebufferingByteBufDataInputPlus inputPlus; + private AsyncStreamingInputPlus inputPlus; private ByteBuf buf; @Before public void setUp() { channel = new EmbeddedChannel(); - inputPlus = new RebufferingByteBufDataInputPlus(1 << 10, 1 << 11, channel.config()); } @After public void tearDown() { - inputPlus.close(); channel.close(); if (buf != null && buf.refCnt() > 0) buf.release(buf.refCnt()); } - @Test (expected = IllegalArgumentException.class) - public void ctor_badWaterMarks() - { - inputPlus = new RebufferingByteBufDataInputPlus(2, 1, null); - } +// @Test +// public void isOpen() +// { +// Assert.assertTrue(inputPlus.isOpen()); +// inputPlus.requestClosure(); +// Assert.assertFalse(inputPlus.isOpen()); +// } @Test - public void isOpen() - { - Assert.assertTrue(inputPlus.isOpen()); - inputPlus.markClose(); - Assert.assertFalse(inputPlus.isOpen()); - } - - @Test (expected = IllegalStateException.class) public void append_closed() { - inputPlus.markClose(); + inputPlus = new AsyncStreamingInputPlus(channel); + inputPlus.requestClosure(); + inputPlus.close(); buf = channel.alloc().buffer(4); - inputPlus.append(buf); + assertFalse(inputPlus.append(buf)); } @Test - public void append_normal() 
throws EOFException + public void append_normal() { + inputPlus = new AsyncStreamingInputPlus(channel); int size = 4; buf = channel.alloc().buffer(size); buf.writerIndex(size); inputPlus.append(buf); - Assert.assertEquals(buf.readableBytes(), inputPlus.available()); + Assert.assertEquals(buf.readableBytes(), inputPlus.unsafeAvailable()); } @Test public void read() throws IOException { + inputPlus = new AsyncStreamingInputPlus(channel); // put two buffers of 8 bytes each into the queue. // then read an int, then a long. the latter tests offset into the inputPlus, as well as spanning across queued buffers. // the values of those int/long will both be '42', but spread across both queue buffers. @@ -102,57 +103,60 @@ public void read() throws IOException buf.writeInt(42); buf.writerIndex(8); inputPlus.append(buf); - Assert.assertEquals(16, inputPlus.available()); - - ByteBuffer out = ByteBuffer.allocate(4); - int readCount = inputPlus.read(out); - Assert.assertEquals(4, readCount); - out.flip(); - Assert.assertEquals(42, out.getInt()); - Assert.assertEquals(12, inputPlus.available()); - - out = ByteBuffer.allocate(8); - readCount = inputPlus.read(out); - Assert.assertEquals(8, readCount); - out.flip(); - Assert.assertEquals(42, out.getLong()); - Assert.assertEquals(4, inputPlus.available()); + Assert.assertEquals(16, inputPlus.unsafeAvailable()); + +// ByteBuffer out = ByteBuffer.allocate(4); +// int readCount = inputPlus.read(out); +// Assert.assertEquals(4, readCount); +// out.flip(); +// Assert.assertEquals(42, out.getInt()); +// Assert.assertEquals(12, inputPlus.unsafeAvailable()); + +// out = ByteBuffer.allocate(8); +// readCount = inputPlus.read(out); +// Assert.assertEquals(8, readCount); +// out.flip(); +// Assert.assertEquals(42, out.getLong()); +// Assert.assertEquals(4, inputPlus.unsafeAvailable()); } - @Test (expected = EOFException.class) - public void read_closed() throws IOException - { - inputPlus.markClose(); - ByteBuffer buf = ByteBuffer.allocate(1); 
- inputPlus.read(buf); - } +// @Test (expected = EOFException.class) +// public void read_closed() throws IOException +// { +// inputPlus.requestClosure(); +// ByteBuffer buf = ByteBuffer.allocate(1); +// inputPlus.read(buf); +// } - @Test (expected = EOFException.class) - public void available_closed() throws EOFException + @Test + public void available_closed() { - inputPlus.markClose(); - inputPlus.available(); + inputPlus = new AsyncStreamingInputPlus(channel); + inputPlus.requestClosure(); + inputPlus.unsafeAvailable(); } @Test - public void available_HappyPath() throws EOFException + public void available_HappyPath() { + inputPlus = new AsyncStreamingInputPlus(channel); int size = 4; buf = channel.alloc().heapBuffer(size); buf.writerIndex(size); inputPlus.append(buf); - Assert.assertEquals(size, inputPlus.available()); + Assert.assertEquals(size, inputPlus.unsafeAvailable()); } @Test - public void available_ClosedButWithBytes() throws EOFException + public void available_ClosedButWithBytes() { + inputPlus = new AsyncStreamingInputPlus(channel); int size = 4; buf = channel.alloc().heapBuffer(size); buf.writerIndex(size); inputPlus.append(buf); - inputPlus.markClose(); - Assert.assertEquals(size, inputPlus.available()); + inputPlus.requestClosure(); + Assert.assertEquals(size, inputPlus.unsafeAvailable()); } @Test @@ -193,6 +197,8 @@ public void consumeUntil_MultipleBuffer_Fails() throws IOException private void consumeUntilTestCycle(int nBuffs, int buffSize, int startOffset, int len) throws IOException { + inputPlus = new AsyncStreamingInputPlus(channel); + byte[] expectedBytes = new byte[len]; int count = 0; for (int j=0; j < nBuffs; j++) @@ -208,16 +214,17 @@ private void consumeUntilTestCycle(int nBuffs, int buffSize, int startOffset, in inputPlus.append(buf); } - inputPlus.append(channel.alloc().buffer(0)); + inputPlus.requestClosure(); TestableWritableByteChannel wbc = new TestableWritableByteChannel(len); inputPlus.skipBytesFully(startOffset); 
BufferedDataOutputStreamPlus writer = new BufferedDataOutputStreamPlus(wbc); - inputPlus.consumeUntil(writer, len); + inputPlus.consume(buffer -> { writer.write(buffer); return buffer.remaining(); }, len); + writer.close(); - Assert.assertEquals(String.format("Test with {} buffers starting at {} consuming {} bytes", nBuffs, startOffset, - len), len, wbc.writtenBytes.readableBytes()); + Assert.assertEquals(String.format("Test with %d buffers starting at %d consuming %d bytes", nBuffs, startOffset, len), + len, wbc.writtenBytes.readableBytes()); Assert.assertArrayEquals(expectedBytes, wbc.writtenBytes.array()); } @@ -232,7 +239,7 @@ public TestableWritableByteChannel(int initialCapacity) writtenBytes = Unpooled.buffer(initialCapacity); } - public int write(ByteBuffer src) throws IOException + public int write(ByteBuffer src) { int size = src.remaining(); writtenBytes.writeBytes(src); @@ -244,9 +251,30 @@ public boolean isOpen() return isOpen; } - public void close() throws IOException + public void close() { isOpen = false; } - }; + } + + @Test + public void rebufferTimeout() throws IOException + { + long timeoutMillis = 1000; + inputPlus = new AsyncStreamingInputPlus(channel, timeoutMillis, TimeUnit.MILLISECONDS); + + long startNanos = System.nanoTime(); + try + { + inputPlus.readInt(); + Assert.fail("should not have been able to read from the queue"); + } + catch (InputTimeoutException e) + { + // this is the success case, and is expected. any other exception is a failure. 
+ } + + long durationNanos = System.nanoTime() - startNanos; + Assert.assertTrue(TimeUnit.MILLISECONDS.toNanos(timeoutMillis) <= durationNanos); + } } diff --git a/test/unit/org/apache/cassandra/net/AsyncStreamingOutputPlusTest.java b/test/unit/org/apache/cassandra/net/AsyncStreamingOutputPlusTest.java new file mode 100644 index 000000000000..fa5009a3f50e --- /dev/null +++ b/test/unit/org/apache/cassandra/net/AsyncStreamingOutputPlusTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.streaming.StreamManager; +import org.apache.cassandra.utils.FBUtilities; + +import static org.junit.Assert.assertEquals; + +public class AsyncStreamingOutputPlusTest +{ + + static + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test + public void testSuccess() throws IOException + { + EmbeddedChannel channel = new TestChannel(4); + ByteBuf read; + try (AsyncStreamingOutputPlus out = new AsyncStreamingOutputPlus(channel)) + { + out.writeInt(1); + assertEquals(0, out.flushed()); + assertEquals(0, out.flushedToNetwork()); + assertEquals(4, out.position()); + + out.doFlush(0); + assertEquals(4, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + out.writeInt(2); + assertEquals(8, out.position()); + assertEquals(4, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + out.doFlush(0); + assertEquals(8, out.position()); + assertEquals(8, out.flushed()); + assertEquals(4, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(4, read.readableBytes()); + assertEquals(1, read.getInt(0)); + assertEquals(8, out.flushed()); + assertEquals(8, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(4, read.readableBytes()); + assertEquals(2, read.getInt(0)); + + out.write(new byte[16]); + assertEquals(24, out.position()); + assertEquals(8, out.flushed()); + assertEquals(8, out.flushedToNetwork()); + + out.doFlush(0); + assertEquals(24, out.position()); + assertEquals(24, out.flushed()); + assertEquals(24, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(16, read.readableBytes()); + assertEquals(0, read.getLong(0)); + assertEquals(0, read.getLong(8)); + assertEquals(24, out.position()); + 
assertEquals(24, out.flushed()); + assertEquals(24, out.flushedToNetwork()); + + out.writeToChannel(alloc -> { + ByteBuffer buffer = alloc.get(16); + buffer.putLong(1); + buffer.putLong(2); + buffer.flip(); + }, new StreamManager.StreamRateLimiter(FBUtilities.getBroadcastAddressAndPort())); + + assertEquals(40, out.position()); + assertEquals(40, out.flushed()); + assertEquals(40, out.flushedToNetwork()); + + read = channel.readOutbound(); + assertEquals(16, read.readableBytes()); + assertEquals(1, read.getLong(0)); + assertEquals(2, read.getLong(8)); + } + + } + +} diff --git a/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java b/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java new file mode 100644 index 000000000000..f90fcd17ca6d --- /dev/null +++ b/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.ChunkedInputPlus; +import org.apache.cassandra.net.ShareableBytes; + +import static org.junit.Assert.*; + +public class ChunkedInputPlusTest +{ + @BeforeClass + public static void setUp() + { + DatabaseDescriptor.clientInitialization(); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyIterable() + { + ChunkedInputPlus.of(Collections.emptyList()); + } + + @Test + public void testUnderRead() throws IOException + { + List chunks = Lists.newArrayList( + chunk(1, 1), chunk(2, 2), chunk(3, 3) + ); + + try (ChunkedInputPlus input = ChunkedInputPlus.of(chunks)) + { + byte[] readBytes = new byte[5]; + input.readFully(readBytes); + assertArrayEquals(new byte[] { 1, 2, 2, 3, 3 }, readBytes); + + assertFalse(chunks.get(0).hasRemaining()); + assertFalse(chunks.get(1).hasRemaining()); + assertTrue (chunks.get(2).hasRemaining()); + + assertTrue (chunks.get(0).isReleased()); + assertTrue (chunks.get(1).isReleased()); + assertFalse(chunks.get(2).isReleased()); + } + + // close should release the last chunk + assertTrue(chunks.get(2).isReleased()); + } + + @Test + public void testExactRead() throws IOException + { + List chunks = Lists.newArrayList( + chunk(1, 1), chunk(2, 2), chunk(3, 3) + ); + + try (ChunkedInputPlus input = ChunkedInputPlus.of(chunks)) + { + byte[] readBytes = new byte[6]; + input.readFully(readBytes); + assertArrayEquals(new byte[] { 1, 2, 2, 3, 3, 3 }, readBytes); + + assertFalse(chunks.get(0).hasRemaining()); + assertFalse(chunks.get(1).hasRemaining()); + assertFalse(chunks.get(2).hasRemaining()); + + assertTrue 
(chunks.get(0).isReleased()); + assertTrue (chunks.get(1).isReleased()); + assertFalse(chunks.get(2).isReleased()); + } + + // close should release the last chunk + assertTrue(chunks.get(2).isReleased()); + } + + @Test + public void testOverRead() throws IOException + { + List chunks = Lists.newArrayList( + chunk(1, 1), chunk(2, 2), chunk(3, 3) + ); + + boolean eofCaught = false; + try (ChunkedInputPlus input = ChunkedInputPlus.of(chunks)) + { + byte[] readBytes = new byte[7]; + input.readFully(readBytes); + assertArrayEquals(new byte[] { 1, 2, 2, 3, 3, 3, 4 }, readBytes); + } + catch (EOFException e) + { + eofCaught = true; + + assertFalse(chunks.get(0).hasRemaining()); + assertFalse(chunks.get(1).hasRemaining()); + assertFalse(chunks.get(2).hasRemaining()); + + assertTrue (chunks.get(2).isReleased()); + assertTrue (chunks.get(1).isReleased()); + assertTrue (chunks.get(2).isReleased()); + } + assertTrue(eofCaught); + } + + @Test + public void testRemainder() throws IOException + { + List chunks = Lists.newArrayList( + chunk(1, 1), chunk(2, 2), chunk(3, 3) + ); + + try (ChunkedInputPlus input = ChunkedInputPlus.of(chunks)) + { + byte[] readBytes = new byte[5]; + input.readFully(readBytes); + assertArrayEquals(new byte[] { 1, 2, 2, 3, 3 }, readBytes); + + assertEquals(1, input.remainder()); + + assertTrue(chunks.get(0).isReleased()); + assertTrue(chunks.get(1).isReleased()); + assertTrue(chunks.get(2).isReleased()); // should be released by remainder() + } + } + + private ShareableBytes chunk(int size, int fill) + { + ByteBuffer buffer = ByteBuffer.allocate(size); + Arrays.fill(buffer.array(), (byte) fill); + return ShareableBytes.wrap(buffer); + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/ConnectionTest.java b/test/unit/org/apache/cassandra/net/ConnectionTest.java new file mode 100644 index 000000000000..17cae7145804 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/ConnectionTest.java @@ -0,0 +1,811 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.function.ToLongFunction; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import 
io.netty.channel.ChannelPromise; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.exceptions.UnknownColumnException; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.utils.FBUtilities; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.ConnectionUtils.*; +import static org.apache.cassandra.net.ConnectionType.LARGE_MESSAGES; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; +import static org.apache.cassandra.net.OutboundConnectionSettings.Framing.LZ4; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +public class ConnectionTest +{ + private static final Logger logger = LoggerFactory.getLogger(ConnectionTest.class); + private static final SocketFactory factory = new SocketFactory(); + + private final Map>> serializers = new HashMap<>(); + private final Map>> handlers = new HashMap<>(); + private final Map> timeouts = new HashMap<>(); + + private void unsafeSetSerializer(Verb verb, Supplier> supplier) throws Throwable + { + serializers.putIfAbsent(verb, verb.unsafeSetSerializer(supplier)); + } + + private void unsafeSetHandler(Verb verb, Supplier> supplier) throws 
Throwable + { + handlers.putIfAbsent(verb, verb.unsafeSetHandler(supplier)); + } + + private void unsafeSetExpiration(Verb verb, ToLongFunction expiration) throws Throwable + { + timeouts.putIfAbsent(verb, verb.unsafeSetExpiration(expiration)); + } + + @After + public void resetVerbs() throws Throwable + { + for (Map.Entry>> e : serializers.entrySet()) + e.getKey().unsafeSetSerializer(e.getValue()); + serializers.clear(); + for (Map.Entry>> e : handlers.entrySet()) + e.getKey().unsafeSetHandler(e.getValue()); + handlers.clear(); + for (Map.Entry> e : timeouts.entrySet()) + e.getKey().unsafeSetExpiration(e.getValue()); + timeouts.clear(); + } + + @BeforeClass + public static void startup() + { + DatabaseDescriptor.daemonInitialization(); + } + + @AfterClass + public static void cleanup() throws InterruptedException + { + factory.shutdownNow(); + } + + interface SendTest + { + void accept(InboundMessageHandlers inbound, OutboundConnection outbound, InetAddressAndPort endpoint) throws Throwable; + } + + interface ManualSendTest + { + void accept(Settings settings, InboundSockets inbound, OutboundConnection outbound, InetAddressAndPort endpoint) throws Throwable; + } + + static class Settings + { + static final Settings SMALL = new Settings(SMALL_MESSAGES); + static final Settings LARGE = new Settings(LARGE_MESSAGES); + final ConnectionType type; + final Function outbound; + final Function inbound; + Settings(ConnectionType type) + { + this(type, Function.identity(), Function.identity()); + } + Settings(ConnectionType type, Function outbound, + Function inbound) + { + this.type = type; + this.outbound = outbound; + this.inbound = inbound; + } + Settings outbound(Function outbound) + { + return new Settings(type, this.outbound.andThen(outbound), inbound); + } + Settings inbound(Function inbound) + { + return new Settings(type, outbound, this.inbound.andThen(inbound)); + } + Settings override(Settings settings) + { + return new Settings(settings.type != null ? 
settings.type : type, + outbound.andThen(settings.outbound), + inbound.andThen(settings.inbound)); + } + } + + static final EncryptionOptions.ServerEncryptionOptions encryptionOptions = + new EncryptionOptions.ServerEncryptionOptions() + .withEnabled(true) + .withLegacySslStoragePort(true) + .withOptional(true) + .withInternodeEncryption(EncryptionOptions.ServerEncryptionOptions.InternodeEncryption.all) + .withKeyStore("test/conf/cassandra_ssl_test.keystore") + .withKeyStorePassword("cassandra") + .withTrustStore("test/conf/cassandra_ssl_test.truststore") + .withTrustStorePassword("cassandra") + .withRequireClientAuth(false) + .withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA"); + + static final AcceptVersions legacy = new AcceptVersions(VERSION_30, VERSION_30); + + static final List> MODIFIERS = ImmutableList.of( + settings -> settings.outbound(outbound -> outbound.withAcceptVersions(legacy)) + .inbound(inbound -> inbound.withAcceptMessaging(legacy)), + settings -> settings.outbound(outbound -> outbound.withEncryption(encryptionOptions)) + .inbound(inbound -> inbound.withEncryption(encryptionOptions)), + settings -> settings.outbound(outbound -> outbound.withFraming(LZ4)) + ); + + static final List SETTINGS = applyPowerSet( + ImmutableList.of(Settings.SMALL, Settings.LARGE), + MODIFIERS + ); + + private static List applyPowerSet(List settings, List> modifiers) + { + List result = new ArrayList<>(); + for (Set> set : Sets.powerSet(new HashSet<>(modifiers))) + { + for (T s : settings) + { + for (Function f : set) + s = f.apply(s); + result.add(s); + } + } + return result; + } + + private void test(Settings extraSettings, SendTest test) throws Throwable + { + for (Settings s : SETTINGS) + doTest(s.override(extraSettings), test); + } + private void test(SendTest test) throws Throwable + { + for (Settings s : SETTINGS) + doTest(s, test); + } + + private void testManual(ManualSendTest test) throws Throwable + { + for (Settings s : SETTINGS) + doTestManual(s, test); + } + 
+ private void doTest(Settings settings, SendTest test) throws Throwable + { + doTestManual(settings, (ignore, inbound, outbound, endpoint) -> { + inbound.open().sync(); + test.accept(MessagingService.instance().getInbound(endpoint), outbound, endpoint); + }); + } + + private void doTestManual(Settings settings, ManualSendTest test) throws Throwable + { + InetAddressAndPort endpoint = FBUtilities.getBroadcastAddressAndPort(); + InboundConnectionSettings inboundSettings = settings.inbound.apply(new InboundConnectionSettings()) + .withBindAddress(endpoint) + .withSocketFactory(factory); + InboundSockets inbound = new InboundSockets(Collections.singletonList(inboundSettings)); + OutboundConnectionSettings outboundTemplate = settings.outbound.apply(new OutboundConnectionSettings(endpoint)) + .withDefaultReserveLimits() + .withSocketFactory(factory) + .withDefaults(ConnectionCategory.MESSAGING); + ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(outboundTemplate.applicationSendQueueReserveEndpointCapacityInBytes), outboundTemplate.applicationSendQueueReserveGlobalCapacityInBytes); + OutboundConnection outbound = new OutboundConnection(settings.type, outboundTemplate, reserveCapacityInBytes); + try + { + logger.info("Running {} {} -> {}", outbound.messagingVersion(), outbound.settings(), inboundSettings); + test.accept(settings, inbound, outbound, endpoint); + } + finally + { + outbound.close(false); + inbound.close().get(30L, SECONDS); + outbound.close(false).get(30L, SECONDS); + resetVerbs(); + MessagingService.instance().messageHandlers.clear(); + } + } + + @Test + public void testSendSmall() throws Throwable + { + test((inbound, outbound, endpoint) -> { + int version = outbound.settings().acceptVersions.max; + int count = 10; + + CountDownLatch deliveryDone = new CountDownLatch(1); + CountDownLatch receiveDone = new CountDownLatch(count); + + unsafeSetHandler(Verb._TEST_1, () -> msg -> 
receiveDone.countDown()); + Message message = Message.out(Verb._TEST_1, noPayload); + for (int i = 0 ; i < count ; ++i) + outbound.enqueue(message); + + Assert.assertTrue(receiveDone.await(10, SECONDS)); + outbound.unsafeRunOnDelivery(deliveryDone::countDown); + Assert.assertTrue(deliveryDone.await(10, SECONDS)); + + check(outbound).submitted(10) + .sent (10, 10 * message.serializedSize(version)) + .pending ( 0, 0) + .overload ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + check(inbound) .received (10, 10 * message.serializedSize(version)) + .processed(10, 10 * message.serializedSize(version)) + .pending ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + }); + } + + @Test + public void testSendLarge() throws Throwable + { + test((inbound, outbound, endpoint) -> { + int version = outbound.settings().acceptVersions.max; + int count = 10; + + CountDownLatch deliveryDone = new CountDownLatch(1); + CountDownLatch receiveDone = new CountDownLatch(count); + + unsafeSetSerializer(Verb._TEST_1, () -> new IVersionedSerializer() + { + public void serialize(Object noPayload, DataOutputPlus out, int version) throws IOException + { + for (int i = 0 ; i < LARGE_MESSAGE_THRESHOLD + 1 ; ++i) + out.writeByte(i); + } + public Object deserialize(DataInputPlus in, int version) throws IOException + { + in.skipBytesFully(LARGE_MESSAGE_THRESHOLD + 1); + return noPayload; + } + public long serializedSize(Object noPayload, int version) + { + return LARGE_MESSAGE_THRESHOLD + 1; + } + }); + unsafeSetHandler(Verb._TEST_1, () -> msg -> receiveDone.countDown()); + Message message = Message.builder(Verb._TEST_1, new Object()) + .withExpiresAt(System.nanoTime() + SECONDS.toNanos(30L)) + .build(); + for (int i = 0 ; i < count ; ++i) + outbound.enqueue(message); + Assert.assertTrue(receiveDone.await(10, SECONDS)); + + outbound.unsafeRunOnDelivery(deliveryDone::countDown); + Assert.assertTrue(deliveryDone.await(10, SECONDS)); + + check(outbound).submitted(10) + .sent (10, 10 * 
message.serializedSize(version)) + .pending ( 0, 0) + .overload ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + check(inbound) .received (10, 10 * message.serializedSize(version)) + .processed(10, 10 * message.serializedSize(version)) + .pending ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + }); + } + + @Test + public void testInsufficientSpace() throws Throwable + { + test(new Settings(null).outbound(settings -> settings + .withApplicationReserveSendQueueCapacityInBytes(1 << 15, new ResourceLimits.Concurrent(1 << 16)) + .withApplicationSendQueueCapacityInBytes(1 << 16)), + (inbound, outbound, endpoint) -> { + + CountDownLatch done = new CountDownLatch(1); + Message message = Message.out(Verb._TEST_1, new Object()); + MessagingService.instance().callbacks.addWithExpiration(new RequestCallback() + { + @Override + public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason) + { + done.countDown(); + } + + @Override + public boolean invokeOnFailure() + { + return true; + } + + @Override + public void onResponse(Message msg) + { + throw new IllegalStateException(); + } + + }, message, endpoint); + AtomicInteger delivered = new AtomicInteger(); + unsafeSetSerializer(Verb._TEST_1, () -> new IVersionedSerializer() + { + public void serialize(Object o, DataOutputPlus out, int version) throws IOException + { + for (int i = 0 ; i <= 4 << 16 ; i += 8L) + out.writeLong(1L); + } + + public Object deserialize(DataInputPlus in, int version) throws IOException + { + in.skipBytesFully(4 << 16); + return null; + } + + public long serializedSize(Object o, int version) + { + return 4 << 16; + } + }); + unsafeSetHandler(Verb._TEST_1, () -> msg -> delivered.incrementAndGet()); + outbound.enqueue(message); + Assert.assertTrue(done.await(10, SECONDS)); + Assert.assertEquals(0, delivered.get()); + check(outbound).submitted( 1) + .sent ( 0, 0) + .pending ( 0, 0) + .overload ( 1, message.serializedSize(current_version)) + .expired ( 0, 0) + .error ( 
0, 0) + .check(); + check(inbound) .received ( 0, 0) + .processed( 0, 0) + .pending ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + }); + } + + @Test + public void testSerializeError() throws Throwable + { + test((inbound, outbound, endpoint) -> { + int version = outbound.settings().acceptVersions.max; + int count = 100; + + CountDownLatch deliveryDone = new CountDownLatch(1); + CountDownLatch receiveDone = new CountDownLatch(90); + + AtomicInteger serialized = new AtomicInteger(); + Message message = Message.builder(Verb._TEST_1, new Object()) + .withExpiresAt(System.nanoTime() + SECONDS.toNanos(30L)) + .build(); + unsafeSetSerializer(Verb._TEST_1, () -> new IVersionedSerializer() + { + public void serialize(Object o, DataOutputPlus out, int version) throws IOException + { + int i = serialized.incrementAndGet(); + if (0 == (i & 15)) + { + if (0 == (i & 16)) + out.writeByte(i); + throw new IOException(); + } + + if (1 != (i & 31)) + out.writeByte(i); + } + + public Object deserialize(DataInputPlus in, int version) throws IOException + { + in.readByte(); + return null; + } + + public long serializedSize(Object o, int version) + { + return 1; + } + }); + + unsafeSetHandler(Verb._TEST_1, () -> msg -> receiveDone.countDown()); + for (int i = 0 ; i < count ; ++i) + outbound.enqueue(message); + + Assert.assertTrue(receiveDone.await(1, MINUTES)); + outbound.unsafeRunOnDelivery(deliveryDone::countDown); + Assert.assertTrue(deliveryDone.await(10, SECONDS)); + + check(outbound).submitted(100) + .sent ( 90, 90 * message.serializedSize(version)) + .pending ( 0, 0) + .overload ( 0, 0) + .expired ( 0, 0) + .error ( 10, 10 * message.serializedSize(version)) + .check(); + check(inbound) .received ( 90, 90 * message.serializedSize(version)) + .processed( 90, 90 * message.serializedSize(version)) + .pending ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + }); + } + + @Test + public void testTimeout() throws Throwable + { + test((inbound, outbound, endpoint) -> { + 
int version = outbound.settings().acceptVersions.max; + int count = 10; + CountDownLatch enqueueDone = new CountDownLatch(1); + CountDownLatch deliveryDone = new CountDownLatch(1); + AtomicInteger delivered = new AtomicInteger(); + Verb._TEST_1.unsafeSetHandler(() -> msg -> delivered.incrementAndGet()); + Message message = Message.builder(Verb._TEST_1, noPayload) + .withExpiresAt(approxTime.now() + TimeUnit.DAYS.toNanos(1L)) + .build(); + long sentSize = message.serializedSize(version); + outbound.enqueue(message); + long timeoutMillis = 10L; + while (delivered.get() < 1); + outbound.unsafeRunOnDelivery(() -> Uninterruptibles.awaitUninterruptibly(enqueueDone, 1L, TimeUnit.DAYS)); + message = Message.builder(Verb._TEST_1, noPayload) + .withExpiresAt(approxTime.now() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis)) + .build(); + for (int i = 0 ; i < count ; ++i) + outbound.enqueue(message); + Uninterruptibles.sleepUninterruptibly(timeoutMillis * 2, TimeUnit.MILLISECONDS); + enqueueDone.countDown(); + outbound.unsafeRunOnDelivery(deliveryDone::countDown); + Assert.assertTrue(deliveryDone.await(1, MINUTES)); + Assert.assertEquals(1, delivered.get()); + check(outbound).submitted( 11) + .sent ( 1, sentSize) + .pending ( 0, 0) + .overload ( 0, 0) + .expired ( 10, 10 * message.serializedSize(current_version)) + .error ( 0, 0) + .check(); + check(inbound) .received ( 1, sentSize) + .processed( 1, sentSize) + .pending ( 0, 0) + .expired ( 0, 0) + .error ( 0, 0) + .check(); + }); + } + + @Test + public void testPre40() throws Throwable + { + MessagingService.instance().versions.set(FBUtilities.getBroadcastAddressAndPort(), + MessagingService.VERSION_30); + + try + { + test((inbound, outbound, endpoint) -> { + CountDownLatch done = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, + () -> (msg) -> done.countDown()); + + Message message = Message.out(Verb._TEST_1, noPayload); + outbound.enqueue(message); + Assert.assertTrue(done.await(1, MINUTES)); + 
Assert.assertTrue(outbound.isConnected()); + }); + } + finally + { + MessagingService.instance().versions.set(FBUtilities.getBroadcastAddressAndPort(), + current_version); + } + } + + @Test + public void testCloseIfEndpointDown() throws Throwable + { + testManual((settings, inbound, outbound, endpoint) -> { + Message message = Message.builder(Verb._TEST_1, noPayload) + .withExpiresAt(System.nanoTime() + SECONDS.toNanos(30L)) + .build(); + + for (int i = 0 ; i < 1000 ; ++i) + outbound.enqueue(message); + + outbound.close(true).get(10L, MINUTES); + }); + } + + @Test + public void testMessagePurging() throws Throwable + { + testManual((settings, inbound, outbound, endpoint) -> { + Runnable testWhileDisconnected = () -> { + try + { + for (int i = 0; i < 5; i++) + { + Message message = Message.builder(Verb._TEST_1, noPayload) + .withExpiresAt(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(50L)) + .build(); + outbound.enqueue(message); + Assert.assertFalse(outbound.isConnected()); + Assert.assertEquals(1, outbound.pendingCount()); + CompletableFuture.runAsync(() -> { + while (outbound.pendingCount() > 0 && !Thread.interrupted()) {} + }).get(10, SECONDS); + // Message should have been purged + Assert.assertEquals(0, outbound.pendingCount()); + } + } + catch (Throwable t) + { + throw new RuntimeException(t); + } + }; + + testWhileDisconnected.run(); + + try + { + inbound.open().sync(); + CountDownLatch receiveDone = new CountDownLatch(1); + CountDownLatch deliveryDone = new CountDownLatch(1); + + unsafeSetHandler(Verb._TEST_1, () -> msg -> receiveDone.countDown()); + outbound.enqueue(Message.out(Verb._TEST_1, noPayload)); + Assert.assertEquals(1, outbound.pendingCount()); + outbound.unsafeRunOnDelivery(deliveryDone::countDown); + + Assert.assertTrue(receiveDone.await(10, SECONDS)); + Assert.assertTrue(deliveryDone.await(10, SECONDS)); + Assert.assertEquals(0, receiveDone.getCount()); + Assert.assertEquals(0, outbound.pendingCount()); + } + finally + { + 
inbound.close().get(10, SECONDS); + // Wait until disconnected + CompletableFuture.runAsync(() -> { + while (outbound.isConnected() && !Thread.interrupted()) {} + }).get(10, SECONDS); + } + + testWhileDisconnected.run(); + }); + } + + @Test + public void testMessageDeliveryOnReconnect() throws Throwable + { + testManual((settings, inbound, outbound, endpoint) -> { + try + { + inbound.open().sync(); + CountDownLatch done = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, () -> msg -> done.countDown()); + outbound.enqueue(Message.out(Verb._TEST_1, noPayload)); + Assert.assertTrue(done.await(10, SECONDS)); + Assert.assertEquals(done.getCount(), 0); + + // Simulate disconnect + inbound.close().get(10, SECONDS); + MessagingService.instance().removeInbound(endpoint); + inbound = new InboundSockets(settings.inbound.apply(new InboundConnectionSettings())); + inbound.open().sync(); + + CountDownLatch latch2 = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, () -> msg -> latch2.countDown()); + outbound.enqueue(Message.out(Verb._TEST_1, noPayload)); + + latch2.await(10, SECONDS); + Assert.assertEquals(latch2.getCount(), 0); + } + finally + { + inbound.close().get(10, SECONDS); + outbound.close(false).get(10, SECONDS); + } + }); + } + + @Test + public void testRecoverableCorruptedMessageDelivery() throws Throwable + { + test((inbound, outbound, endpoint) -> { + int version = outbound.settings().acceptVersions.max; + if (version < VERSION_40) + return; + + AtomicInteger counter = new AtomicInteger(); + unsafeSetSerializer(Verb._TEST_1, () -> new IVersionedSerializer() + { + public void serialize(Object o, DataOutputPlus out, int version) throws IOException + { + out.writeInt((Integer) o); + } + + public Object deserialize(DataInputPlus in, int version) throws IOException + { + if (counter.getAndIncrement() == 3) + throw new UnknownColumnException(""); + + return in.readInt(); + } + + public long serializedSize(Object o, int version) + { + return Integer.BYTES; + } 
+ }); + + // Connect + connect(outbound); + + CountDownLatch latch = new CountDownLatch(4); + unsafeSetHandler(Verb._TEST_1, () -> message -> latch.countDown()); + for (int i = 0; i < 5; i++) + outbound.enqueue(Message.out(Verb._TEST_1, 0xffffffff)); + + latch.await(10, SECONDS); + Assert.assertEquals(0, latch.getCount()); + Assert.assertEquals(6, counter.get()); + }); + } + + @Test + public void testCRCCorruption() throws Throwable + { + test((inbound, outbound, endpoint) -> { + int version = outbound.settings().acceptVersions.max; + if (version < VERSION_40) + return; + + unsafeSetSerializer(Verb._TEST_1, () -> new IVersionedSerializer() + { + public void serialize(Object o, DataOutputPlus out, int version) throws IOException + { + out.writeInt((Integer) o); + } + + public Object deserialize(DataInputPlus in, int version) throws IOException + { + return in.readInt(); + } + + public long serializedSize(Object o, int version) + { + return Integer.BYTES; + } + }); + + connect(outbound); + + outbound.unsafeGetChannel().pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ByteBuf bb = (ByteBuf) msg; + bb.setByte(0, 0xAB); + ctx.write(msg, promise); + } + }); + outbound.enqueue(Message.out(Verb._TEST_1, 0xffffffff)); + CompletableFuture.runAsync(() -> { + while (outbound.isConnected() && !Thread.interrupted()) {} + }).get(10, SECONDS); + Assert.assertFalse(outbound.isConnected()); + // TODO: count corruptions + + connect(outbound); + }); + } + + @Test + public void testAcquireReleaseOutbound() throws Throwable + { + test((inbound, outbound, endpoint) -> { + ExecutorService executor = Executors.newFixedThreadPool(100); + int acquireStep = 123; + Assert.assertTrue(outbound.unsafeAcquireCapacity(100 * 10000, 100 * 10000 * acquireStep)); + AtomicLong acquisitionFailures = new AtomicLong(); + for (int i = 0; i < 100; i++) + { + executor.submit(() -> { + for (int j = 
0; j < 10000; j++) + { + if (!outbound.unsafeAcquireCapacity(acquireStep)) + acquisitionFailures.incrementAndGet(); + } + + }); + } + + for (int i = 0; i < 100; i++) + { + executor.submit(() -> { + for (int j = 0; j < 10000; j++) + outbound.unsafeReleaseCapacity(acquireStep); + }); + } + + executor.shutdown(); + executor.awaitTermination(10, TimeUnit.SECONDS); + + // We can release more than we acquire, which certainly should not happen in + // real life, but since it's a test just for acquisition and release, it is fine + Assert.assertEquals(100 * 10000 * acquireStep - (acquisitionFailures.get() * acquireStep), outbound.pendingBytes()); + }); + } + + private void connect(OutboundConnection outbound) throws Throwable + { + CountDownLatch latch = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, () -> message -> latch.countDown()); + outbound.enqueue(Message.out(Verb._TEST_1, 0xffffffff)); + latch.await(10, SECONDS); + Assert.assertEquals(0, latch.getCount()); + Assert.assertTrue(outbound.isConnected()); + } + +} diff --git a/test/unit/org/apache/cassandra/net/ConnectionUtils.java b/test/unit/org/apache/cassandra/net/ConnectionUtils.java new file mode 100644 index 000000000000..e3917851d855 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/ConnectionUtils.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Assert; + +import org.apache.cassandra.net.InboundMessageHandlers; +import org.apache.cassandra.net.OutboundConnection; + +public class ConnectionUtils +{ + public interface FailCheck + { + public void accept(String message, long expected, long actual); + } + + public static class OutboundCountChecker + { + private final OutboundConnection connection; + private long submitted; + private long pending, pendingBytes; + private long sent, sentBytes; + private long overload, overloadBytes; + private long expired, expiredBytes; + private long error, errorBytes; + private boolean checkSubmitted, checkPending, checkSent, checkOverload, checkExpired, checkError; + + private OutboundCountChecker(OutboundConnection connection) + { + this.connection = connection; + } + + public OutboundCountChecker submitted(long count) + { + submitted = count; + checkSubmitted = true; + return this; + } + + public OutboundCountChecker pending(long count, long bytes) + { + pending = count; + pendingBytes = bytes; + checkPending = true; + return this; + } + + public OutboundCountChecker sent(long count, long bytes) + { + sent = count; + sentBytes = bytes; + checkSent = true; + return this; + } + + public OutboundCountChecker overload(long count, long bytes) + { + overload = count; + overloadBytes = bytes; + checkOverload = true; + return this; + } + + public OutboundCountChecker expired(long count, long bytes) + { + 
expired = count; + expiredBytes = bytes; + checkExpired = true; + return this; + } + + public OutboundCountChecker error(long count, long bytes) + { + error = count; + errorBytes = bytes; + checkError = true; + return this; + } + + public void check() + { + doCheck(Assert::assertEquals); + } + + public void check(FailCheck failCheck) + { + doCheck((message, expect, actual) -> { if (expect != actual) failCheck.accept(message, expect, actual); }); + } + + private void doCheck(FailCheck testAndFailCheck) + { + if (checkSubmitted) + { + testAndFailCheck.accept("submitted count values don't match", submitted, connection.submittedCount()); + } + if (checkPending) + { + testAndFailCheck.accept("pending count values don't match", pending, connection.pendingCount()); + testAndFailCheck.accept("pending bytes values don't match", pendingBytes, connection.pendingBytes()); + } + if (checkSent) + { + testAndFailCheck.accept("sent count values don't match", sent, connection.sentCount()); + testAndFailCheck.accept("sent bytes values don't match", sentBytes, connection.sentBytes()); + } + if (checkOverload) + { + testAndFailCheck.accept("overload count values don't match", overload, connection.overloadedCount()); + testAndFailCheck.accept("overload bytes values don't match", overloadBytes, connection.overloadedBytes()); + } + if (checkExpired) + { + testAndFailCheck.accept("expired count values don't match", expired, connection.expiredCount()); + testAndFailCheck.accept("expired bytes values don't match", expiredBytes, connection.expiredBytes()); + } + if (checkError) + { + testAndFailCheck.accept("error count values don't match", error, connection.errorCount()); + testAndFailCheck.accept("error bytes values don't match", errorBytes, connection.errorBytes()); + } + } + } + + public static class InboundCountChecker + { + private final InboundMessageHandlers connection; + private long scheduled, scheduledBytes; + private long received, receivedBytes; + private long processed, 
processedBytes; + private long expired, expiredBytes; + private long error, errorBytes; + private boolean checkScheduled, checkReceived, checkProcessed, checkExpired, checkError; + + private InboundCountChecker(InboundMessageHandlers connection) + { + this.connection = connection; + } + + public InboundCountChecker pending(long count, long bytes) + { + scheduled = count; + scheduledBytes = bytes; + checkScheduled = true; + return this; + } + + public InboundCountChecker received(long count, long bytes) + { + received = count; + receivedBytes = bytes; + checkReceived = true; + return this; + } + + public InboundCountChecker processed(long count, long bytes) + { + processed = count; + processedBytes = bytes; + checkProcessed = true; + return this; + } + + public InboundCountChecker expired(long count, long bytes) + { + expired = count; + expiredBytes = bytes; + checkExpired = true; + return this; + } + + public InboundCountChecker error(long count, long bytes) + { + error = count; + errorBytes = bytes; + checkError = true; + return this; + } + + public void check() + { + doCheck(Assert::assertEquals); + } + + public void check(FailCheck failCheck) + { + doCheck((message, expect, actual) -> { if (expect != actual) failCheck.accept(message, expect, actual); }); + } + + private void doCheck(FailCheck testAndFailCheck) + { + if (checkReceived) + { + testAndFailCheck.accept("received count values don't match", received, connection.receivedCount()); + testAndFailCheck.accept("received bytes values don't match", receivedBytes, connection.receivedBytes()); + } + if (checkProcessed) + { + testAndFailCheck.accept("processed count values don't match", processed, connection.processedCount()); + testAndFailCheck.accept("processed bytes values don't match", processedBytes, connection.processedBytes()); + } + if (checkExpired) + { + testAndFailCheck.accept("expired count values don't match", expired, connection.expiredCount()); + testAndFailCheck.accept("expired bytes values don't 
match", expiredBytes, connection.expiredBytes()); + } + if (checkError) + { + testAndFailCheck.accept("error count values don't match", error, connection.errorCount()); + testAndFailCheck.accept("error bytes values don't match", errorBytes, connection.errorBytes()); + } + if (checkScheduled) + { + // scheduled cannot relied upon to not race with completion of the task, + // so if it is currently above the value we expect, sleep for a bit + if (scheduled < connection.scheduledCount()) + for (int i = 0; i < 10 && scheduled < connection.scheduledCount() ; ++i) + Uninterruptibles.sleepUninterruptibly(1L, TimeUnit.MILLISECONDS); + testAndFailCheck.accept("scheduled count values don't match", scheduled, connection.scheduledCount()); + testAndFailCheck.accept("scheduled bytes values don't match", scheduledBytes, connection.scheduledBytes()); + } + } + } + + public static OutboundCountChecker check(OutboundConnection outbound) + { + return new OutboundCountChecker(outbound); + } + + public static InboundCountChecker check(InboundMessageHandlers inbound) + { + return new InboundCountChecker(inbound); + } + +} diff --git a/test/unit/org/apache/cassandra/net/ForwardToContainerTest.java b/test/unit/org/apache/cassandra/net/ForwardingInfoTest.java similarity index 89% rename from test/unit/org/apache/cassandra/net/ForwardToContainerTest.java rename to test/unit/org/apache/cassandra/net/ForwardingInfoTest.java index 195d734c1bf9..16dec9f34639 100644 --- a/test/unit/org/apache/cassandra/net/ForwardToContainerTest.java +++ b/test/unit/org/apache/cassandra/net/ForwardingInfoTest.java @@ -33,7 +33,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class ForwardToContainerTest +public class ForwardingInfoTest { @Test public void testCurrent() throws Exception @@ -57,20 +57,20 @@ private void testVersion(int version) throws Exception InetAddressAndPort.getByName("2001:0db8:0000:0000:0000:ff00:0042:8329"), 
InetAddressAndPort.getByName("[2001:0db8:0000:0000:0000:ff00:0042:8329]:7000")); - ForwardToContainer ftc = new ForwardToContainer(addresses, new int[] { 44, 45, 46, 47, 48, 49 }); + ForwardingInfo ftc = new ForwardingInfo(addresses, new long[] { 44, 45, 46, 47, 48, 49 }); ByteBuffer buffer; try (DataOutputBuffer dob = new DataOutputBuffer()) { - ForwardToSerializer.instance.serialize(ftc, dob, version); + ForwardingInfo.serializer.serialize(ftc, dob, version); buffer = dob.buffer(); } - assertEquals(buffer.remaining(), ForwardToSerializer.instance.serializedSize(ftc, version)); + assertEquals(buffer.remaining(), ForwardingInfo.serializer.serializedSize(ftc, version)); - ForwardToContainer deserialized; + ForwardingInfo deserialized; try (DataInputBuffer dib = new DataInputBuffer(buffer, false)) { - deserialized = ForwardToSerializer.instance.deserialize(dib, version); + deserialized = ForwardingInfo.serializer.deserialize(dib, version); } assertTrue(Arrays.equals(ftc.messageIds, deserialized.messageIds)); diff --git a/test/unit/org/apache/cassandra/net/FramingTest.java b/test/unit/org/apache/cassandra/net/FramingTest.java new file mode 100644 index 000000000000..8a7f4283bdfa --- /dev/null +++ b/test/unit/org/apache/cassandra/net/FramingTest.java @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.compress.BufferType; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputBuffer; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.utils.memory.BufferPool; +import org.apache.cassandra.utils.vint.VIntCoding; + +import static java.lang.Math.*; +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_3014; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.MessagingService.minimum_version; +import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD; +import static org.apache.cassandra.net.ShareableBytes.wrap; + +// TODO: test corruption +// TODO: use a different random seed each time +// TODO: use quick theories +public class FramingTest +{ + private static final Logger logger = LoggerFactory.getLogger(FramingTest.class); + + @BeforeClass + public static void begin() throws NoSuchFieldException, IllegalAccessException + { + DatabaseDescriptor.daemonInitialization(); + Verb._TEST_1.unsafeSetSerializer(() -> new IVersionedSerializer() + { + + public void serialize(byte[] t, DataOutputPlus out, int version) throws IOException + { + 
out.writeUnsignedVInt(t.length); + out.write(t); + } + + public byte[] deserialize(DataInputPlus in, int version) throws IOException + { + byte[] r = new byte[(int) in.readUnsignedVInt()]; + in.readFully(r); + return r; + } + + public long serializedSize(byte[] t, int version) + { + return VIntCoding.computeUnsignedVIntSize(t.length) + t.length; + } + }); + } + + @AfterClass + public static void after() throws NoSuchFieldException, IllegalAccessException + { + Verb._TEST_1.unsafeSetSerializer(() -> null); + } + + private static class SequenceOfFrames + { + final List original; + final int[] boundaries; + final ShareableBytes frames; + + private SequenceOfFrames(List original, int[] boundaries, ByteBuffer frames) + { + this.original = original; + this.boundaries = boundaries; + this.frames = wrap(frames); + } + } + + @Test + public void testRandomLZ4() + { + testSomeFrames(FrameEncoderLZ4.fastInstance, FrameDecoderLZ4.fast(GlobalBufferPoolAllocator.instance)); + } + + @Test + public void testRandomCrc() + { + testSomeFrames(FrameEncoderCrc.instance, FrameDecoderCrc.create(GlobalBufferPoolAllocator.instance)); + } + + private void testSomeFrames(FrameEncoder encoder, FrameDecoder decoder) + { + long seed = new SecureRandom().nextLong(); + logger.info("seed: {}, decoder: {}", seed, decoder.getClass().getSimpleName()); + Random random = new Random(seed); + for (int i = 0 ; i < 1000 ; ++i) + testRandomSequenceOfFrames(random, encoder, decoder); + } + + private void testRandomSequenceOfFrames(Random random, FrameEncoder encoder, FrameDecoder decoder) + { + SequenceOfFrames sequenceOfFrames = sequenceOfFrames(random, encoder); + + List uncompressed = sequenceOfFrames.original; + ShareableBytes frames = sequenceOfFrames.frames; + int[] boundaries = sequenceOfFrames.boundaries; + + int end = frames.get().limit(); + List out = new ArrayList<>(); + int prevBoundary = -1; + for (int i = 0 ; i < end ; ) + { + int limit = i + random.nextInt(1 + end - i); + decoder.decode(out, 
frames.slice(i, limit)); + int boundary = Arrays.binarySearch(boundaries, limit); + if (boundary < 0) boundary = -2 -boundary; + + while (prevBoundary < boundary) + { + ++prevBoundary; + Assert.assertTrue(out.size() >= 1 + prevBoundary); + verify(uncompressed.get(prevBoundary), ((FrameDecoder.IntactFrame) out.get(prevBoundary)).contents); + } + i = limit; + } + for (FrameDecoder.Frame frame : out) + frame.release(); + frames.release(); + Assert.assertNull(decoder.stash); + Assert.assertTrue(decoder.frames.isEmpty()); + } + + private static void verify(byte[] expect, ShareableBytes actual) + { + verify(expect, 0, expect.length, actual); + } + + private static void verify(byte[] expect, int start, int end, ShareableBytes actual) + { + byte[] fetch = new byte[end - start]; + Assert.assertEquals(end - start, actual.remaining()); + actual.get().get(fetch); + boolean equals = true; + for (int i = start ; equals && i < end ; ++i) + equals = expect[i] == fetch[i - start]; + if (!equals) + Assert.assertArrayEquals(Arrays.copyOfRange(expect, start, end), fetch); + } + + private static SequenceOfFrames sequenceOfFrames(Random random, FrameEncoder encoder) + { + int frameCount = 1 + random.nextInt(8); + List uncompressed = new ArrayList<>(); + List compressed = new ArrayList<>(); + int[] cumulativeCompressedLength = new int[frameCount]; + for (int i = 0 ; i < frameCount ; ++i) + { + byte[] bytes = randomishBytes(random, 1, 1 << 15); + uncompressed.add(bytes); + + FrameEncoder.Payload payload = encoder.allocator().allocate(true, bytes.length); + payload.buffer.put(bytes); + payload.finish(); + + ByteBuf buffer = encoder.encode(true, payload.buffer); + compressed.add(buffer); + cumulativeCompressedLength[i] = (i == 0 ? 
0 : cumulativeCompressedLength[i - 1]) + buffer.readableBytes(); + } + + ByteBuffer frames = BufferPool.getAtLeast(cumulativeCompressedLength[frameCount - 1], BufferType.OFF_HEAP); + for (ByteBuf buffer : compressed) + { + frames.put(buffer.internalNioBuffer(buffer.readerIndex(), buffer.readableBytes())); + buffer.release(); + } + frames.flip(); + return new SequenceOfFrames(uncompressed, cumulativeCompressedLength, frames); + } + + @Test + public void burnRandomLegacy() + { + burnRandomLegacy(1000); + } + + private void burnRandomLegacy(int count) + { + SecureRandom seed = new SecureRandom(); + Random random = new Random(); + for (int i = 0 ; i < count ; ++i) + { + long innerSeed = seed.nextLong(); + float ratio = seed.nextFloat(); + int version = minimum_version + random.nextInt(1 + current_version - minimum_version); + logger.debug("seed: {}, ratio: {}, version: {}", innerSeed, ratio, version); + random.setSeed(innerSeed); + testRandomSequenceOfMessages(random, ratio, version, new FrameDecoderLegacy(GlobalBufferPoolAllocator.instance, version)); + } + } + + @Test + public void testRandomLegacy() + { + testRandomLegacy(250); + } + + private void testRandomLegacy(int count) + { + SecureRandom seeds = new SecureRandom(); + for (int messagingVersion : new int[] { VERSION_30, VERSION_3014, current_version}) + { + FrameDecoder decoder = new FrameDecoderLegacy(GlobalBufferPoolAllocator.instance, messagingVersion); + testSomeMessages(seeds.nextLong(), count, 0.0f, messagingVersion, decoder); + testSomeMessages(seeds.nextLong(), count, 0.1f, messagingVersion, decoder); + testSomeMessages(seeds.nextLong(), count, 0.95f, messagingVersion, decoder); + testSomeMessages(seeds.nextLong(), count, 1.0f, messagingVersion, decoder); + } + } + + private void testSomeMessages(long seed, int count, float largeRatio, int messagingVersion, FrameDecoder decoder) + { + logger.info("seed: {}, iterations: {}, largeRatio: {}, messagingVersion: {}, decoder: {}", seed, count, largeRatio, 
messagingVersion, decoder.getClass().getSimpleName()); + Random random = new Random(seed); + for (int i = 0 ; i < count ; ++i) + { + long innerSeed = random.nextLong(); + logger.debug("inner seed: {}, iteration: {}", innerSeed, i); + random.setSeed(innerSeed); + testRandomSequenceOfMessages(random, largeRatio, messagingVersion, decoder); + } + } + + private void testRandomSequenceOfMessages(Random random, float largeRatio, int messagingVersion, FrameDecoder decoder) + { + SequenceOfFrames sequenceOfMessages = sequenceOfMessages(random, largeRatio, messagingVersion); + + List messages = sequenceOfMessages.original; + ShareableBytes stream = sequenceOfMessages.frames; + + int end = stream.get().limit(); + List out = new ArrayList<>(); + + int messageStart = 0; + int messageIndex = 0; + for (int i = 0 ; i < end ; ) + { + int limit = i + random.nextInt(1 + end - i); + decoder.decode(out, stream.slice(i, limit)); + + int outIndex = 0; + byte[] message = messages.get(messageIndex); + if (i > messageStart) + { + int start; + if (message.length <= LARGE_MESSAGE_THRESHOLD) + { + start = 0; + } + else if (!lengthIsReadable(message, i - messageStart, messagingVersion)) + { + // we should have an initial frame containing only some prefix of the message (probably 64 bytes) + // that was stashed only to decide how big the message was + FrameDecoder.IntactFrame frame = (FrameDecoder.IntactFrame) out.get(outIndex++); + Assert.assertFalse(frame.isSelfContained); + start = frame.contents.remaining(); + verify(message, 0, frame.contents.remaining(), frame.contents); + } + else + { + start = i - messageStart; + } + + if (limit >= message.length + messageStart) + { + FrameDecoder.IntactFrame frame = (FrameDecoder.IntactFrame) out.get(outIndex++); + Assert.assertEquals(start == 0, frame.isSelfContained); + // verify remainder of a large message, or a single fully stashed small message + verify(message, start, message.length, frame.contents); + + messageStart += message.length; + if 
(++messageIndex < messages.size()) + message = messages.get(messageIndex); + } + else if (message.length > LARGE_MESSAGE_THRESHOLD) + { + FrameDecoder.IntactFrame frame = (FrameDecoder.IntactFrame) out.get(outIndex++); + Assert.assertFalse(frame.isSelfContained); + // verify next portion of a large message + verify(message, start, limit - messageStart, frame.contents); + + Assert.assertEquals(outIndex, out.size()); + for (FrameDecoder.Frame f : out) + f.release(); + out.clear(); + i = limit; + continue; + } + } + + // message is fresh + int beginFrameIndex = messageIndex; + while (messageStart + message.length <= limit) + { + messageStart += message.length; + if (++messageIndex < messages.size()) + message = messages.get(messageIndex); + } + + if (beginFrameIndex < messageIndex) + { + FrameDecoder.IntactFrame frame = (FrameDecoder.IntactFrame) out.get(outIndex++); + Assert.assertTrue(frame.isSelfContained); + while (beginFrameIndex < messageIndex) + { + byte[] m = messages.get(beginFrameIndex); + ShareableBytes bytesToVerify = frame.contents.sliceAndConsume(m.length); + verify(m, bytesToVerify); + bytesToVerify.release(); + ++beginFrameIndex; + } + Assert.assertFalse(frame.contents.hasRemaining()); + } + + if (limit > messageStart + && message.length > LARGE_MESSAGE_THRESHOLD + && lengthIsReadable(message, limit - messageStart, messagingVersion)) + { + FrameDecoder.IntactFrame frame = (FrameDecoder.IntactFrame) out.get(outIndex++); + Assert.assertFalse(frame.isSelfContained); + verify(message, 0, limit - messageStart, frame.contents); + } + + Assert.assertEquals(outIndex, out.size()); + for (FrameDecoder.Frame frame : out) + frame.release(); + out.clear(); + + i = limit; + } + stream.release(); + Assert.assertTrue(stream.isReleased()); + Assert.assertNull(decoder.stash); + Assert.assertTrue(decoder.frames.isEmpty()); + } + + private static boolean lengthIsReadable(byte[] message, int limit, int messagingVersion) + { + try + { + return 
Message.serializer.inferMessageSize(ByteBuffer.wrap(message), 0, limit, messagingVersion) >= 0; + } + catch (Message.InvalidLegacyProtocolMagic e) + { + throw new IllegalStateException(e); + } + } + + private static SequenceOfFrames sequenceOfMessages(Random random, float largeRatio, int messagingVersion) + { + int messageCount = 1 + random.nextInt(63); + List messages = new ArrayList<>(); + int[] cumulativeLength = new int[messageCount]; + for (int i = 0 ; i < messageCount ; ++i) + { + byte[] payload; + if (random.nextFloat() < largeRatio) payload = randomishBytes(random, 1 << 16, 1 << 17); + else payload = randomishBytes(random, 1, 1 << 16); + Message messageObj = Message.out(Verb._TEST_1, payload); + + byte[] message; + try (DataOutputBuffer out = new DataOutputBuffer(messageObj.serializedSize(messagingVersion))) + { + Message.serializer.serialize(messageObj, out, messagingVersion); + message = out.toByteArray(); + } + catch (IOException e) + { + throw new IllegalStateException(e); + } + messages.add(message); + + cumulativeLength[i] = (i == 0 ? 
0 : cumulativeLength[i - 1]) + message.length; + } + + ByteBuffer frames = BufferPool.getAtLeast(cumulativeLength[messageCount - 1], BufferType.OFF_HEAP); + for (byte[] buffer : messages) + frames.put(buffer); + frames.flip(); + return new SequenceOfFrames(messages, cumulativeLength, frames); + } + + private static byte[] randomishBytes(Random random, int minLength, int maxLength) + { + byte[] bytes = new byte[minLength + random.nextInt(maxLength - minLength)]; + int runLength = 1 + random.nextInt(255); + for (int i = 0 ; i < bytes.length ; i += runLength) + { + byte b = (byte) random.nextInt(256); + Arrays.fill(bytes, i, min(bytes.length, i + runLength), b); + } + return bytes; + } + +} diff --git a/test/unit/org/apache/cassandra/net/HandshakeTest.java b/test/unit/org/apache/cassandra/net/HandshakeTest.java new file mode 100644 index 000000000000..c9d4e8715b33 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/HandshakeTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.nio.channels.ClosedChannelException; +import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.channel.EventLoop; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result; +import org.apache.cassandra.net.OutboundConnectionInitiator.Result.MessagingSuccess; + +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_3014; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.MessagingService.minimum_version; +import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES; +import static org.apache.cassandra.net.OutboundConnectionInitiator.*; + +// TODO: test failure due to exception, timeout, etc +public class HandshakeTest +{ + private static final SocketFactory factory = new SocketFactory(); + + @BeforeClass + public static void startup() + { + DatabaseDescriptor.daemonInitialization(); + } + + @AfterClass + public static void cleanup() throws InterruptedException + { + factory.shutdownNow(); + } + + private Result handshake(int req, int outMin, int outMax) throws ExecutionException, InterruptedException + { + return handshake(req, new AcceptVersions(outMin, outMax), null); + } + private Result handshake(int req, int outMin, int outMax, int inMin, int inMax) throws ExecutionException, InterruptedException + { + return handshake(req, new AcceptVersions(outMin, outMax), new AcceptVersions(inMin, inMax)); + } + private Result handshake(int req, AcceptVersions acceptOutbound, AcceptVersions acceptInbound) throws 
ExecutionException, InterruptedException + { + InboundSockets inbound = new InboundSockets(new InboundConnectionSettings().withAcceptMessaging(acceptInbound)); + try + { + inbound.open(); + InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> s.settings.bindAddress).findFirst().get(); + EventLoop eventLoop = factory.defaultGroup().next(); + Future> future = + initiateMessaging(eventLoop, + SMALL_MESSAGES, + new OutboundConnectionSettings(endpoint) + .withAcceptVersions(acceptOutbound) + .withDefaults(ConnectionCategory.MESSAGING), + req, new AsyncPromise<>(eventLoop)); + return future.get(); + } + finally + { + inbound.close().await(1L, TimeUnit.SECONDS); + } + } + + @Test + public void testBothCurrentVersion() throws InterruptedException, ExecutionException + { + Result result = handshake(current_version, minimum_version, current_version); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + result.success().channel.close(); + } + + @Test + public void testSendCompatibleOldVersion() throws InterruptedException, ExecutionException + { + Result result = handshake(current_version, current_version, current_version + 1, current_version +1, current_version + 2); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + Assert.assertEquals(current_version + 1, result.success().messagingVersion); + result.success().channel.close(); + } + + @Test + public void testSendCompatibleFutureVersion() throws InterruptedException, ExecutionException + { + Result result = handshake(current_version + 1, current_version - 1, current_version + 1); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + Assert.assertEquals(current_version, result.success().messagingVersion); + result.success().channel.close(); + } + + @Test + public void testSendIncompatibleFutureVersion() throws InterruptedException, ExecutionException + { + Result result = handshake(current_version + 1, current_version + 1, current_version + 1); + 
Assert.assertEquals(Result.Outcome.INCOMPATIBLE, result.outcome); + Assert.assertEquals(current_version, result.incompatible().closestSupportedVersion); + Assert.assertEquals(current_version, result.incompatible().maxMessagingVersion); + } + + @Test + public void testSendIncompatibleOldVersion() throws InterruptedException, ExecutionException + { + Result result = handshake(current_version + 1, current_version + 1, current_version + 1, current_version + 2, current_version + 3); + Assert.assertEquals(Result.Outcome.INCOMPATIBLE, result.outcome); + Assert.assertEquals(current_version + 2, result.incompatible().closestSupportedVersion); + Assert.assertEquals(current_version + 3, result.incompatible().maxMessagingVersion); + } + + @Test + public void testSendCompatibleMaxVersionPre40() throws InterruptedException, ExecutionException + { + Result result = handshake(VERSION_3014, VERSION_30, VERSION_3014, VERSION_30, VERSION_3014); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + Assert.assertEquals(VERSION_3014, result.success().messagingVersion); + result.success().channel.close(); + } + + @Test + public void testSendCompatibleFutureVersionPre40() throws InterruptedException, ExecutionException + { + Result result = handshake(VERSION_3014, VERSION_30, VERSION_3014, VERSION_30, VERSION_30); + Assert.assertEquals(Result.Outcome.RETRY, result.outcome); + Assert.assertEquals(VERSION_30, result.retry().withMessagingVersion); + } + + @Test + public void testSendIncompatibleFutureVersionPre40() throws InterruptedException, ExecutionException + { + Result result = handshake(VERSION_3014, VERSION_3014, VERSION_3014, VERSION_30, VERSION_30); + Assert.assertEquals(Result.Outcome.INCOMPATIBLE, result.outcome); + Assert.assertEquals(-1, result.incompatible().closestSupportedVersion); + Assert.assertEquals(VERSION_30, result.incompatible().maxMessagingVersion); + } + + @Test + public void testSendCompatibleOldVersionPre40() throws InterruptedException + { + try + { + 
handshake(VERSION_30, VERSION_30, VERSION_3014, VERSION_3014, VERSION_3014); + Assert.fail("Should have thrown"); + } + catch (ExecutionException e) + { + Assert.assertTrue(e.getCause() instanceof ClosedChannelException); + } + } + + @Test + public void testSendIncompatibleOldVersionPre40() throws InterruptedException + { + try + { + handshake(VERSION_30, VERSION_30, VERSION_30, VERSION_3014, VERSION_3014); + Assert.fail("Should have thrown"); + } + catch (ExecutionException e) + { + Assert.assertTrue(e.getCause() instanceof ClosedChannelException); + } + } + + @Test + public void testSendCompatibleOldVersion40() throws InterruptedException, ExecutionException + { + Result result = handshake(VERSION_30, VERSION_30, VERSION_30, VERSION_30, current_version); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + Assert.assertEquals(VERSION_30, result.success().messagingVersion); + } + + @Test + public void testSendIncompatibleOldVersion40() throws InterruptedException + { + try + { + Assert.fail(Objects.toString(handshake(VERSION_30, VERSION_30, VERSION_30, current_version, current_version))); + } + catch (ExecutionException e) + { + Assert.assertTrue(e.getCause() instanceof ClosedChannelException); + } + } + + @Test // fairly contrived case, but since we introduced logic for testing we need to be careful it doesn't make us worse + public void testSendToFuturePost40BelievedToBePre40() throws InterruptedException, ExecutionException + { + Result result = handshake(VERSION_30, VERSION_30, current_version, VERSION_30, current_version + 1); + Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome); + Assert.assertEquals(VERSION_30, result.success().messagingVersion); + } +} diff --git a/test/unit/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueueTest.java b/test/unit/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueueTest.java new file mode 100644 index 000000000000..2c92a392e09f --- /dev/null +++ 
b/test/unit/org/apache/cassandra/net/ManyToOneConcurrentLinkedQueueTest.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.util.BitSet; +import java.util.NoSuchElementException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import org.junit.Test; + +import static org.junit.Assert.*; + +@SuppressWarnings("ConstantConditions") +public class ManyToOneConcurrentLinkedQueueTest +{ + private final ManyToOneConcurrentLinkedQueue queue = new ManyToOneConcurrentLinkedQueue<>(); + + @Test + public void testRelaxedIsEmptyWhenEmpty() + { + assertTrue(queue.relaxedIsEmpty()); + } + + @Test + public void testRelaxedIsEmptyWhenNotEmpty() + { + queue.offer(0); + assertFalse(queue.relaxedIsEmpty()); + } + + @Test + public void testSizeWhenEmpty() + { + assertEquals(0, queue.size()); + } + + @Test + public void testSizeWhenNotEmpty() + { + queue.offer(0); + assertEquals(1, queue.size()); + + for (int i = 1; i < 100; i++) + queue.offer(i); + assertEquals(100, queue.size()); + } + + @Test + public void testEmptyPeek() + { + assertNull(queue.peek()); + } + + @Test + public void testNonEmptyPeek() + { + queue.offer(0); + 
assertEquals(0, (int) queue.peek()); + } + + @Test + public void testEmptyPoll() + { + assertNull(queue.poll()); + } + + @Test + public void testNonEmptyPoll() + { + queue.offer(0); + assertEquals(0, (int) queue.poll()); + } + + @Test(expected = NoSuchElementException.class) + public void testEmptyRemove() + { + queue.remove(); + } + + @Test + public void testNonEmptyRemove() + { + queue.offer(0); + assertEquals(0, (int) queue.remove()); + } + + @Test + public void testOtherRemoveWhenEmpty() + { + assertFalse(queue.remove(0)); + } + + @Test + public void testOtherRemoveSingleNode() + { + queue.offer(0); + assertTrue(queue.remove(0)); + assertTrue(queue.isEmpty()); + } + + @Test + public void testOtherRemoveWhenFirst() + { + queue.offer(0); + queue.offer(1); + queue.offer(2); + + assertTrue(queue.remove(0)); + + assertEquals(1, (int) queue.poll()); + assertEquals(2, (int) queue.poll()); + assertNull(queue.poll()); + } + + @Test + public void testOtherRemoveFromMiddle() + { + queue.offer(0); + queue.offer(1); + queue.offer(2); + + assertTrue(queue.remove(1)); + + assertEquals(0, (int) queue.poll()); + assertEquals(2, (int) queue.poll()); + assertNull(queue.poll()); + } + + @Test + public void testOtherRemoveFromEnd() + { + queue.offer(0); + queue.offer(1); + queue.offer(2); + + assertTrue(queue.remove(2)); + + assertEquals(0, (int) queue.poll()); + assertEquals(1, (int) queue.poll()); + assertNull(queue.poll()); + } + + @Test + public void testOtherRemoveWhenDoesnNotExist() + { + queue.offer(0); + queue.offer(1); + queue.offer(2); + + assertFalse(queue.remove(3)); + + assertEquals(0, (int) queue.poll()); + assertEquals(1, (int) queue.poll()); + assertEquals(2, (int) queue.poll()); + } + + @Test + public void testTransfersInCorrectOrder() + { + for (int i = 0; i < 1024; i++) + queue.offer(i); + + for (int i = 0; i < 1024; i++) + assertEquals(i, (int) queue.poll()); + + assertTrue(queue.relaxedIsEmpty()); + } + + @Test + public void 
testTransfersInCorrectOrderWhenInterleaved() + { + for (int i = 0; i < 1024; i++) + { + queue.offer(i); + assertEquals(i, (int) queue.poll()); + } + + assertTrue(queue.relaxedIsEmpty()); + } + + @Test + public void testDrain() + { + for (int i = 0; i < 1024; i++) + queue.offer(i); + + class Consumer + { + private int previous = -1; + + public void accept(int i) + { + assertEquals(++previous, i); + } + } + + Consumer consumer = new Consumer(); + queue.drain(consumer::accept); + + assertEquals(1023, consumer.previous); + assertTrue(queue.relaxedIsEmpty()); + } + + @Test + public void testPeekLastAndOffer() + { + assertNull(queue.relaxedPeekLastAndOffer(0)); + for (int i = 1; i < 1024; i++) + assertEquals(i - 1, (int) queue.relaxedPeekLastAndOffer(i)); + + for (int i = 0; i < 1024; i++) + assertEquals(i, (int) queue.poll()); + + assertTrue(queue.relaxedIsEmpty()); + } + + enum Strategy + { + PEEK_AND_REMOVE, POLL + } + + @Test + public void testConcurrentlyWithPoll() + { + testConcurrently(Strategy.POLL); + } + + @Test + public void testConcurrentlyWithPeekAndRemove() + { + testConcurrently(Strategy.PEEK_AND_REMOVE); + } + + private void testConcurrently(Strategy strategy) + { + int numThreads = 4; + int numItems = 1_000_000 * numThreads; + + class Producer implements Runnable + { + private final int start, step, limit; + + private Producer(int start, int step, int limit) + { + this.start = start; + this.step = step; + this.limit = limit; + } + + public void run() + { + for (int i = start; i < limit; i += step) + queue.offer(i); + } + } + + Executor executor = Executors.newFixedThreadPool(numThreads); + for (int i = 0; i < numThreads; i++) + executor.execute(new Producer(i, numThreads, numItems)); + + BitSet itemsPolled = new BitSet(numItems); + for (int i = 0; i < numItems; i++) + { + Integer item; + switch (strategy) + { + case PEEK_AND_REMOVE: + //noinspection StatementWithEmptyBody + while ((item = queue.peek()) == null) ; + assertFalse(queue.relaxedIsEmpty()); + 
assertEquals(item, queue.remove()); + itemsPolled.set(item); + break; + case POLL: + //noinspection StatementWithEmptyBody + while ((item = queue.poll()) == null) ; + itemsPolled.set(item); + break; + } + } + + assertEquals(numItems, itemsPolled.cardinality()); + assertTrue(queue.relaxedIsEmpty()); + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/Matcher.java b/test/unit/org/apache/cassandra/net/Matcher.java index 27b685f6bb26..6f8e1e7aee69 100644 --- a/test/unit/org/apache/cassandra/net/Matcher.java +++ b/test/unit/org/apache/cassandra/net/Matcher.java @@ -28,5 +28,5 @@ public interface Matcher * @param obj intercepted outgoing message * @param to destination address */ - public boolean matches(MessageOut obj, InetAddressAndPort to); + public boolean matches(Message obj, InetAddressAndPort to); } diff --git a/test/unit/org/apache/cassandra/net/MatcherResponse.java b/test/unit/org/apache/cassandra/net/MatcherResponse.java index 7a1772aec550..b2bba8b0a06c 100644 --- a/test/unit/org/apache/cassandra/net/MatcherResponse.java +++ b/test/unit/org/apache/cassandra/net/MatcherResponse.java @@ -17,15 +17,20 @@ */ package org.apache.cassandra.net; -import java.util.Collections; -import java.util.HashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; import java.util.Queue; -import java.util.Set; import java.util.concurrent.BlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import java.util.function.BiPredicate; import java.util.function.Function; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; + +import org.apache.cassandra.concurrent.StageManager; import org.apache.cassandra.locator.InetAddressAndPort; /** @@ -36,10 +41,11 @@ public class MatcherResponse { private final Matcher matcher; - private final Set sendResponses = new HashSet<>(); + private final Multimap sendResponses = + 
Multimaps.newListMultimap(new HashMap<>(), ArrayList::new); private final MockMessagingSpy spy = new MockMessagingSpy(); private final AtomicInteger limitCounter = new AtomicInteger(Integer.MAX_VALUE); - private IMessageSink sink; + private BiPredicate, InetAddressAndPort> sink; MatcherResponse(Matcher matcher) { @@ -51,33 +57,33 @@ public class MatcherResponse */ public MockMessagingSpy dontReply() { - return respond((MessageIn)null); + return respond((Message)null); } /** - * Respond with provided message in reply to each intercepted outbound message. - * @param message the message to use as mock reply from the cluster + * Respond with provided message in response to each intercepted outbound message. + * @param message the message to use as mock response from the cluster */ - public MockMessagingSpy respond(MessageIn message) + public MockMessagingSpy respond(Message message) { return respondN(message, Integer.MAX_VALUE); } /** - * Respond a limited number of times with the provided message in reply to each intercepted outbound message. - * @param response the message to use as mock reply from the cluster + * Respond a limited number of times with the provided message in response to each intercepted outbound message. + * @param response the message to use as mock response from the cluster * @param limit number of times to respond with message */ - public MockMessagingSpy respondN(final MessageIn response, int limit) + public MockMessagingSpy respondN(final Message response, int limit) { return respondN((in, to) -> response, limit); } /** * Respond with the message created by the provided function that will be called with each intercepted outbound message. 
- * @param fnResponse function to call for creating reply based on intercepted message and target address + * @param fnResponse function to call for creating response based on intercepted message and target address */ - public MockMessagingSpy respond(BiFunction, InetAddressAndPort, MessageIn> fnResponse) + public MockMessagingSpy respond(BiFunction, InetAddressAndPort, Message> fnResponse) { return respondN(fnResponse, Integer.MAX_VALUE); } @@ -86,9 +92,9 @@ public MockMessagingSpy respond(BiFunction, InetAddressAndP * Respond with message wrapping the payload object created by provided function called for each intercepted outbound message. * The target address from the intercepted message will automatically be used as the created message's sender address. * @param fnResponse function to call for creating payload object based on intercepted message and target address - * @param verb verb to use for reply message + * @param verb verb to use for response message */ - public MockMessagingSpy respondWithPayloadForEachReceiver(Function, S> fnResponse, MessagingService.Verb verb) + public MockMessagingSpy respondWithPayloadForEachReceiver(Function, S> fnResponse, Verb verb) { return respondNWithPayloadForEachReceiver(fnResponse, verb, Integer.MAX_VALUE); } @@ -98,40 +104,40 @@ public MockMessagingSpy respondWithPayloadForEachReceiver(Function MockMessagingSpy respondNWithPayloadForEachReceiver(Function, S> fnResponse, MessagingService.Verb verb, int limit) + public MockMessagingSpy respondNWithPayloadForEachReceiver(Function, S> fnResponse, Verb verb, int limit) { - return respondN((MessageOut msg, InetAddressAndPort to) -> { + return respondN((Message msg, InetAddressAndPort to) -> { S payload = fnResponse.apply(msg); if (payload == null) return null; else - return MessageIn.create(to, payload, Collections.emptyMap(), verb, MessagingService.current_version); + return Message.builder(verb, payload).from(to).build(); }, limit); } /** * Responds to each intercepted 
outbound message by creating a response message wrapping the next element consumed - * from the provided queue. No reply will be send when the queue has been exhausted. + * from the provided queue. No response will be sent when the queue has been exhausted. * @param cannedResponses prepared payload messages to use for responses - * @param verb verb to use for reply message + * @param verb verb to use for response message */ - public MockMessagingSpy respondWithPayloadForEachReceiver(Queue cannedResponses, MessagingService.Verb verb) + public MockMessagingSpy respondWithPayloadForEachReceiver(Queue cannedResponses, Verb verb) { - return respondWithPayloadForEachReceiver((MessageOut msg) -> cannedResponses.poll(), verb); + return respondWithPayloadForEachReceiver((Message msg) -> cannedResponses.poll(), verb); } /** * Responds to each intercepted outbound message by creating a response message wrapping the next element consumed * from the provided queue. This method will block until queue elements are available.
* @param cannedResponses prepared payload messages to use for responses - * @param verb verb to use for reply message + * @param verb verb to use for response message */ - public MockMessagingSpy respondWithPayloadForEachReceiver(BlockingQueue cannedResponses, MessagingService.Verb verb) + public MockMessagingSpy respondWithPayloadForEachReceiver(BlockingQueue cannedResponses, Verb verb) { - return respondWithPayloadForEachReceiver((MessageOut msg) -> { + return respondWithPayloadForEachReceiver((Message msg) -> { try { return cannedResponses.take(); @@ -146,17 +152,17 @@ public MockMessagingSpy respondWithPayloadForEachReceiver(BlockingQueue MockMessagingSpy respondN(BiFunction, InetAddressAndPort, MessageIn> fnResponse, int limit) + public MockMessagingSpy respondN(BiFunction, InetAddressAndPort, Message> fnResponse, int limit) { limitCounter.set(limit); assert sink == null: "destroy() must be called first to register new response"; - sink = new IMessageSink() + sink = new BiPredicate, InetAddressAndPort>() { - public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to) + public boolean test(Message message, InetAddressAndPort to) { // prevent outgoing message from being send in case matcher indicates a match // and instead send the mocked response @@ -169,23 +175,25 @@ public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPo synchronized (sendResponses) { - // I'm not sure about retry semantics regarding message/ID relationships, but I assume - // sending a message multiple times using the same ID shouldn't happen.. 
- assert !sendResponses.contains(id) : "ID re-use for outgoing message"; - sendResponses.add(id); + if (message.hasId()) + { + assert !sendResponses.get(message.id()).contains(to) : "ID re-use for outgoing message"; + sendResponses.put(message.id(), to); + } } // create response asynchronously to match request/response communication execution behavior new Thread(() -> { - MessageIn response = fnResponse.apply(message, to); + Message response = fnResponse.apply(message, to); if (response != null) { - CallbackInfo cb = MessagingService.instance().getRegisteredCallback(id); + RequestCallbacks.CallbackInfo cb = MessagingService.instance().callbacks.get(message.id(), to); if (cb != null) - cb.callback.response(response); + cb.callback.onResponse(response); else - MessagingService.instance().receive(response, id); + processResponse(response); + spy.matchingResponse(response); } }).start(); @@ -194,22 +202,34 @@ public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPo } return true; } - - public boolean allowIncomingMessage(MessageIn message, int id) - { - return true; - } }; - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); return spy; } + private void processResponse(Message message) + { + if (!MessagingService.instance().inboundSink.allow(message)) + return; + + StageManager.getStage(message.verb().stage).execute(() -> { + try + { + message.verb().handler().doVerb((Message)message); + } + catch (IOException e) + { + // + } + }); + } + /** * Stops currently registered response from being send. 
*/ public void destroy() { - MessagingService.instance().removeMessageSink(sink); + MessagingService.instance().outboundSink.remove(sink); } } diff --git a/test/unit/org/apache/cassandra/net/MessageDeliveryTaskTest.java b/test/unit/org/apache/cassandra/net/MessageDeliveryTaskTest.java deleted file mode 100644 index db38efb7234d..000000000000 --- a/test/unit/org/apache/cassandra/net/MessageDeliveryTaskTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net; - -import java.net.UnknownHostException; -import java.util.Collections; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; - -public class MessageDeliveryTaskTest -{ - private static final MockVerbHandler VERB_HANDLER = new MockVerbHandler(); - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - MessagingService.instance().registerVerbHandlers(MessagingService.Verb.UNUSED_2, VERB_HANDLER); - } - - @AfterClass - public static void after() - { - MessagingService.instance().removeVerbHandler(MessagingService.Verb.UNUSED_2); - } - - @Before - public void setUp() - { - VERB_HANDLER.reset(); - } - - @Test - public void process_HappyPath() throws UnknownHostException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - MessageIn msg = MessageIn.create(addr, null, Collections.emptyMap(), MessagingService.Verb.UNUSED_2, 1); - MessageDeliveryTask task = new MessageDeliveryTask(msg, 42); - Assert.assertTrue(task.process()); - Assert.assertEquals(1, VERB_HANDLER.invocationCount); - } - - @Test - public void process_NullVerb() throws UnknownHostException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - MessageIn msg = MessageIn.create(addr, null, Collections.emptyMap(), null, 1); - MessageDeliveryTask task = new MessageDeliveryTask(msg, 42); - Assert.assertFalse(task.process()); - } - - @Test - public void process_NoHandler() throws UnknownHostException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - MessageIn msg = MessageIn.create(addr, null, Collections.emptyMap(), MessagingService.Verb.UNUSED_5, 1); - MessageDeliveryTask task = new MessageDeliveryTask(msg, 42); - Assert.assertFalse(task.process()); - } - - 
@Test - public void process_ExpiredDroppableMessage() throws UnknownHostException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - - // we need any droppable verb, so just grab it from the enum itself rather than hard code a value - MessageIn msg = MessageIn.create(addr, null, Collections.emptyMap(), MessagingService.DROPPABLE_VERBS.iterator().next(), 1, 0); - MessageDeliveryTask task = new MessageDeliveryTask(msg, 42); - Assert.assertFalse(task.process()); - } - - // non-droppable message should still be processed even if they are expired - @Test - public void process_ExpiredMessage() throws UnknownHostException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - MessageIn msg = MessageIn.create(addr, null, Collections.emptyMap(), MessagingService.Verb.UNUSED_2, 1, 0); - MessageDeliveryTask task = new MessageDeliveryTask(msg, 42); - Assert.assertTrue(task.process()); - Assert.assertEquals(1, VERB_HANDLER.invocationCount); - } - - private static class MockVerbHandler implements IVerbHandler - { - private int invocationCount; - - @Override - public void doVerb(MessageIn message, int id) - { - invocationCount++; - } - - void reset() - { - invocationCount = 0; - } - } -} diff --git a/test/unit/org/apache/cassandra/net/MessageInTest.java b/test/unit/org/apache/cassandra/net/MessageInTest.java deleted file mode 100644 index b9ea7da7f95b..000000000000 --- a/test/unit/org/apache/cassandra/net/MessageInTest.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Collections; - -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; - -public class MessageInTest -{ - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - } - - // make sure deserializing message doesn't crash with an unknown verb - @Test - public void read_NullVerb() throws IOException - { - read(null); - } - - @Test - public void read_NoSerializer() throws IOException - { - read(MessagingService.Verb.UNUSED_5); - } - - private void read(MessagingService.Verb verb) throws IOException - { - InetAddressAndPort addr = InetAddressAndPort.getByName("127.0.0.1"); - ByteBuffer buf = ByteBuffer.allocate(64); - buf.limit(buf.capacity()); - DataInputPlus dataInputBuffer = new DataInputBuffer(buf, false); - int payloadSize = 27; - Assert.assertEquals(0, buf.position()); - Assert.assertNotNull(MessageIn.read(dataInputBuffer, 1, 42, 0, addr, payloadSize, verb, Collections.emptyMap())); - Assert.assertEquals(payloadSize, buf.position()); - } -} diff --git a/test/unit/org/apache/cassandra/net/MessageTest.java b/test/unit/org/apache/cassandra/net/MessageTest.java new file mode 100644 index 000000000000..78eb4c0f3e6d --- /dev/null +++ 
b/test/unit/org/apache/cassandra/net/MessageTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputBuffer; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.tracing.Tracing.TraceType; +import org.apache.cassandra.utils.FBUtilities; + +import static org.apache.cassandra.net.Message.serializer; +import static org.apache.cassandra.net.MessagingService.VERSION_3014; +import static org.apache.cassandra.net.MessagingService.VERSION_30; +import static org.apache.cassandra.net.MessagingService.VERSION_40; +import static 
org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.ParamType.RESPOND_TO; +import static org.apache.cassandra.net.ParamType.TRACE_SESSION; +import static org.apache.cassandra.net.ParamType.TRACE_TYPE; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; +import static org.junit.Assert.*; + +public class MessageTest +{ + @BeforeClass + public static void setUpClass() throws Exception + { + DatabaseDescriptor.daemonInitialization(); + DatabaseDescriptor.setCrossNodeTimeout(true); + + Verb._TEST_2.unsafeSetSerializer(() -> new IVersionedSerializer() + { + public void serialize(Integer value, DataOutputPlus out, int version) throws IOException + { + out.writeInt(value); + } + + public Integer deserialize(DataInputPlus in, int version) throws IOException + { + return in.readInt(); + } + + public long serializedSize(Integer value, int version) + { + return 4; + } + }); + } + + @AfterClass + public static void tearDownClass() throws Exception + { + Verb._TEST_2.unsafeSetSerializer(() -> NoPayload.serializer); + } + + @Test + public void testInferMessageSize() throws Exception + { + Message msg = + Message.builder(Verb._TEST_2, 37) + .withId(1) + .from(FBUtilities.getLocalAddressAndPort()) + .withCreatedAt(approxTime.now()) + .withExpiresAt(approxTime.now()) + .withFlag(MessageFlag.CALL_BACK_ON_FAILURE) + .withFlag(MessageFlag.TRACK_REPAIRED_DATA) + .withParam(TRACE_TYPE, TraceType.QUERY) + .withParam(TRACE_SESSION, UUID.randomUUID()) + .build(); + + testInferMessageSize(msg, VERSION_30); + testInferMessageSize(msg, VERSION_3014); + testInferMessageSize(msg, VERSION_40); + } + + private void testInferMessageSize(Message msg, int version) throws Exception + { + try (DataOutputBuffer out = new DataOutputBuffer()) + { + serializer.serialize(msg, out, version); + assertEquals(msg.serializedSize(version), out.getLength()); + + ByteBuffer buffer = out.buffer(); + + int payloadSize = (int) 
msg.verb().serializer().serializedSize(msg.payload, version); + int serializedSize = msg.serializedSize(version); + + // should return -1 - fail to infer size - for all lengths of buffer until payload length can be read + for (int limit = 0; limit < serializedSize - payloadSize; limit++) + assertEquals(-1, serializer.inferMessageSize(buffer, 0, limit, version)); + + // once payload size can be read, should correctly infer message size + for (int limit = serializedSize - payloadSize; limit < serializedSize; limit++) + assertEquals(serializedSize, serializer.inferMessageSize(buffer, 0, limit, version)); + } + } + + @Test + public void testBuilder() + { + long id = 1; + InetAddressAndPort from = FBUtilities.getLocalAddressAndPort(); + long createAtNanos = approxTime.now(); + long expiresAtNanos = createAtNanos + TimeUnit.SECONDS.toNanos(1); + TraceType traceType = TraceType.QUERY; + UUID traceSession = UUID.randomUUID(); + + Message msg = + Message.builder(Verb._TEST_1, noPayload) + .withId(1) + .from(from) + .withCreatedAt(createAtNanos) + .withExpiresAt(expiresAtNanos) + .withFlag(MessageFlag.CALL_BACK_ON_FAILURE) + .withParam(TRACE_TYPE, TraceType.QUERY) + .withParam(TRACE_SESSION, traceSession) + .build(); + + assertEquals(id, msg.id()); + assertEquals(from, msg.from()); + assertEquals(createAtNanos, msg.createdAtNanos()); + assertEquals(expiresAtNanos, msg.expiresAtNanos()); + assertTrue(msg.callBackOnFailure()); + assertFalse(msg.trackRepairedData()); + assertEquals(traceType, msg.traceType()); + assertEquals(traceSession, msg.traceSession()); + assertNull(msg.forwardTo()); + assertNull(msg.respondTo()); + } + + @Test + public void testCycleNoPayload() throws IOException + { + Message msg = + Message.builder(Verb._TEST_1, noPayload) + .withId(1) + .from(FBUtilities.getLocalAddressAndPort()) + .withCreatedAt(approxTime.now()) + .withExpiresAt(approxTime.now() + TimeUnit.SECONDS.toNanos(1)) + .withFlag(MessageFlag.CALL_BACK_ON_FAILURE) + .withParam(TRACE_SESSION, 
UUID.randomUUID()) + .build(); + testCycle(msg); + } + + @Test + public void testCycleWithPayload() throws Exception + { + testCycle(Message.out(Verb._TEST_2, 42)); + testCycle(Message.outWithFlag(Verb._TEST_2, 42, MessageFlag.CALL_BACK_ON_FAILURE)); + testCycle(Message.outWithFlags(Verb._TEST_2, 42, MessageFlag.CALL_BACK_ON_FAILURE, MessageFlag.TRACK_REPAIRED_DATA)); + testCycle(Message.outWithParam(1, Verb._TEST_2, 42, RESPOND_TO, FBUtilities.getBroadcastAddressAndPort())); + } + + @Test + public void testFailureResponse() throws IOException + { + long expiresAt = approxTime.now(); + Message msg = Message.failureResponse(1, expiresAt, RequestFailureReason.INCOMPATIBLE_SCHEMA); + + assertEquals(1, msg.id()); + assertEquals(Verb.FAILURE_RSP, msg.verb()); + assertEquals(expiresAt, msg.expiresAtNanos()); + assertEquals(RequestFailureReason.INCOMPATIBLE_SCHEMA, msg.payload); + assertTrue(msg.isFailureResponse()); + + testCycle(msg); + } + + private void testCycle(Message msg) throws IOException + { + testCycle(msg, VERSION_30); + testCycle(msg, VERSION_3014); + testCycle(msg, VERSION_40); + } + + // serialize (using both variants, all in one or header then rest), verify serialized size, deserialize, compare to the original + private void testCycle(Message msg, int version) throws IOException + { + try (DataOutputBuffer out = new DataOutputBuffer()) + { + serializer.serialize(msg, out, version); + assertEquals(msg.serializedSize(version), out.getLength()); + + // deserialize the message in one go, compare outcomes + try (DataInputBuffer in = new DataInputBuffer(out.buffer(), true)) + { + Message msgOut = serializer.deserialize(in, msg.from(), version); + assertEquals(0, in.available()); + assertMessagesEqual(msg, msgOut); + } + + // extract header first, then deserialize the rest of the message and compare outcomes + ByteBuffer buffer = out.buffer(); + try (DataInputBuffer in = new DataInputBuffer(out.buffer(), false)) + { + Message.Header headerOut = 
serializer.extractHeader(buffer, msg.from(), approxTime.now(), version); + Message msgOut = serializer.deserialize(in, headerOut, version); + assertEquals(0, in.available()); + assertMessagesEqual(msg, msgOut); + } + } + } + + private static void assertMessagesEqual(Message msg1, Message msg2) + { + assertEquals(msg1.id(), msg2.id()); + assertEquals(msg1.verb(), msg2.verb()); + assertEquals(msg1.callBackOnFailure(), msg2.callBackOnFailure()); + assertEquals(msg1.trackRepairedData(), msg2.trackRepairedData()); + assertEquals(msg1.traceType(), msg2.traceType()); + assertEquals(msg1.traceSession(), msg2.traceSession()); + assertEquals(msg1.respondTo(), msg2.respondTo()); + assertEquals(msg1.forwardTo(), msg2.forwardTo()); + + Object payload1 = msg1.payload; + Object payload2 = msg2.payload; + + if (null == payload1) + assertTrue(payload2 == noPayload || payload2 == null); + else if (null == payload2) + assertSame(payload1, noPayload); + else + assertEquals(payload1, payload2); + } +} diff --git a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java index b56cd62950dd..76922f614f38 100644 --- a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java +++ b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java @@ -23,13 +23,14 @@ import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.regex.*; import java.util.regex.Matcher; @@ -39,20 +40,13 @@ import com.codahale.metrics.Timer; -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; import org.apache.cassandra.auth.IInternodeAuthenticator; import 
org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.db.SystemKeyspace; -import org.apache.cassandra.db.monitoring.ApproximateTime; +import org.apache.cassandra.metrics.MessagingMetrics; +import org.apache.cassandra.utils.ApproximateTime; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService.ServerChannel; -import org.apache.cassandra.net.async.NettyFactory; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier; -import org.apache.cassandra.net.async.OutboundConnectionParams; -import org.apache.cassandra.net.async.OutboundMessagingPool; import org.apache.cassandra.utils.FBUtilities; import org.caffinitas.ohc.histo.EstimatedHistogram; import org.junit.After; @@ -61,6 +55,7 @@ import org.junit.BeforeClass; import org.junit.Test; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.*; public class MessagingServiceTest @@ -83,7 +78,7 @@ public void validateConfiguration() throws ConfigurationException private static ServerEncryptionOptions originalServerEncryptionOptions; private static InetAddressAndPort originalListenAddress; - private final MessagingService messagingService = MessagingService.test(); + private final MessagingService messagingService = new MessagingService(true); @BeforeClass public static void beforeClass() throws UnknownHostException @@ -101,10 +96,10 @@ public static void beforeClass() throws UnknownHostException @Before public void before() throws UnknownHostException { - messagingService.resetDroppedMessagesMap(Integer.toString(metricScopeId++)); + messagingService.metrics.resetDroppedMessages(Integer.toString(metricScopeId++)); MockBackPressureStrategy.applied = false; - messagingService.destroyConnectionPool(InetAddressAndPort.getByName("127.0.0.2")); - 
messagingService.destroyConnectionPool(InetAddressAndPort.getByName("127.0.0.3")); + messagingService.closeOutbound(InetAddressAndPort.getByName("127.0.0.2")); + messagingService.closeOutbound(InetAddressAndPort.getByName("127.0.0.3")); } @After @@ -120,29 +115,32 @@ public void tearDown() @Test public void testDroppedMessages() { - MessagingService.Verb verb = MessagingService.Verb.READ; + Verb verb = Verb.READ_REQ; for (int i = 1; i <= 5000; i++) - messagingService.incrementDroppedMessages(verb, i, i % 2 == 0); + messagingService.metrics.recordDroppedMessage(verb, i, MILLISECONDS, i % 2 == 0); - List logs = messagingService.getDroppedMessagesLogs(); + List logs = new ArrayList<>(); + messagingService.metrics.resetAndConsumeDroppedErrors(logs::add); assertEquals(1, logs.size()); - Pattern regexp = Pattern.compile("READ messages were dropped in last 5000 ms: (\\d+) internal and (\\d+) cross node. Mean internal dropped latency: (\\d+) ms and Mean cross-node dropped latency: (\\d+) ms"); + Pattern regexp = Pattern.compile("READ_REQ messages were dropped in last 5000 ms: (\\d+) internal and (\\d+) cross node. 
Mean internal dropped latency: (\\d+) ms and Mean cross-node dropped latency: (\\d+) ms"); Matcher matcher = regexp.matcher(logs.get(0)); assertTrue(matcher.find()); assertEquals(2500, Integer.parseInt(matcher.group(1))); assertEquals(2500, Integer.parseInt(matcher.group(2))); assertTrue(Integer.parseInt(matcher.group(3)) > 0); assertTrue(Integer.parseInt(matcher.group(4)) > 0); - assertEquals(5000, (int) messagingService.getDroppedMessages().get(verb.toString())); + assertEquals(5000, (int) messagingService.metrics.getDroppedMessages().get(verb.toString())); - logs = messagingService.getDroppedMessagesLogs(); + logs.clear(); + messagingService.metrics.resetAndConsumeDroppedErrors(logs::add); assertEquals(0, logs.size()); for (int i = 0; i < 2500; i++) - messagingService.incrementDroppedMessages(verb, i, i % 2 == 0); + messagingService.metrics.recordDroppedMessage(verb, i, MILLISECONDS, i % 2 == 0); - logs = messagingService.getDroppedMessagesLogs(); + logs.clear(); + messagingService.metrics.resetAndConsumeDroppedErrors(logs::add); assertEquals(1, logs.size()); matcher = regexp.matcher(logs.get(0)); assertTrue(matcher.find()); @@ -150,57 +148,53 @@ public void testDroppedMessages() assertEquals(1250, Integer.parseInt(matcher.group(2))); assertTrue(Integer.parseInt(matcher.group(3)) > 0); assertTrue(Integer.parseInt(matcher.group(4)) > 0); - assertEquals(7500, (int) messagingService.getDroppedMessages().get(verb.toString())); + assertEquals(7500, (int) messagingService.metrics.getDroppedMessages().get(verb.toString())); } @Test public void testDCLatency() throws Exception { int latency = 100; - ConcurrentHashMap dcLatency = MessagingService.instance().metrics.dcLatency; + ConcurrentHashMap dcLatency = MessagingService.instance().metrics.dcLatency; dcLatency.clear(); - long now = ApproximateTime.currentTimeMillis(); + long now = System.currentTimeMillis(); long sentAt = now - latency; assertNull(dcLatency.get("datacenter1")); addDCLatency(sentAt, now); 
assertNotNull(dcLatency.get("datacenter1")); - assertEquals(1, dcLatency.get("datacenter1").getCount()); - long expectedBucket = bucketOffsets[Math.abs(Arrays.binarySearch(bucketOffsets, TimeUnit.MILLISECONDS.toNanos(latency))) - 1]; - assertEquals(expectedBucket, dcLatency.get("datacenter1").getSnapshot().getMax()); + assertEquals(1, dcLatency.get("datacenter1").dcLatency.getCount()); + long expectedBucket = bucketOffsets[Math.abs(Arrays.binarySearch(bucketOffsets, MILLISECONDS.toNanos(latency))) - 1]; + assertEquals(expectedBucket, dcLatency.get("datacenter1").dcLatency.getSnapshot().getMax()); } @Test - public void testNegativeDCLatency() throws Exception + public void testNegativeDCLatency() { + MessagingMetrics.DCLatencyRecorder updater = MessagingService.instance().metrics.internodeLatencyRecorder(InetAddressAndPort.getLocalHost()); + + // if clocks are off should just not track anything int latency = -100; - ConcurrentHashMap dcLatency = MessagingService.instance().metrics.dcLatency; - dcLatency.clear(); - - long now = ApproximateTime.currentTimeMillis(); + long now = System.currentTimeMillis(); long sentAt = now - latency; - assertNull(dcLatency.get("datacenter1")); - addDCLatency(sentAt, now); - assertNull(dcLatency.get("datacenter1")); + long count = updater.dcLatency.getCount(); + updater.accept(now - sentAt, MILLISECONDS); + // negative value shouldn't be recorded + assertEquals(count, updater.dcLatency.getCount()); } @Test - public void testQueueWaitLatency() throws Exception + public void testQueueWaitLatency() { int latency = 100; - String verb = MessagingService.Verb.MUTATION.toString(); - - ConcurrentHashMap queueWaitLatency = MessagingService.instance().metrics.queueWaitLatency; - queueWaitLatency.clear(); + Verb verb = Verb.MUTATION_REQ; - assertNull(queueWaitLatency.get(verb)); - MessagingService.instance().metrics.addQueueWaitTime(verb, latency); - assertNotNull(queueWaitLatency.get(verb)); + Map queueWaitLatency = 
MessagingService.instance().metrics.internalLatency; + MessagingService.instance().metrics.recordInternalLatency(verb, latency, MILLISECONDS); assertEquals(1, queueWaitLatency.get(verb).getCount()); - long expectedBucket = bucketOffsets[Math.abs(Arrays.binarySearch(bucketOffsets, TimeUnit.MILLISECONDS.toNanos(latency))) - 1]; + long expectedBucket = bucketOffsets[Math.abs(Arrays.binarySearch(bucketOffsets, MILLISECONDS.toNanos(latency))) - 1]; assertEquals(expectedBucket, queueWaitLatency.get(verb).getSnapshot().getMax()); } @@ -208,13 +202,13 @@ public void testQueueWaitLatency() throws Exception public void testNegativeQueueWaitLatency() throws Exception { int latency = -100; - String verb = MessagingService.Verb.MUTATION.toString(); + Verb verb = Verb.MUTATION_REQ; - ConcurrentHashMap queueWaitLatency = MessagingService.instance().metrics.queueWaitLatency; + Map queueWaitLatency = MessagingService.instance().metrics.internalLatency; queueWaitLatency.clear(); assertNull(queueWaitLatency.get(verb)); - MessagingService.instance().metrics.addQueueWaitTime(verb, latency); + MessagingService.instance().metrics.recordInternalLatency(verb, latency, MILLISECONDS); assertNull(queueWaitLatency.get(verb)); } @@ -222,9 +216,9 @@ public void testNegativeQueueWaitLatency() throws Exception public void testUpdatesBackPressureOnSendWhenEnabledAndWithSupportedCallback() throws UnknownHostException { MockBackPressureStrategy.MockBackPressureState backPressureState = (MockBackPressureStrategy.MockBackPressureState) messagingService.getBackPressureState(InetAddressAndPort.getByName("127.0.0.2")); - IAsyncCallback bpCallback = new BackPressureCallback(); - IAsyncCallback noCallback = new NoBackPressureCallback(); - MessageOut ignored = null; + RequestCallback bpCallback = new BackPressureCallback(); + RequestCallback noCallback = new NoBackPressureCallback(); + Message ignored = null; DatabaseDescriptor.setBackPressureEnabled(true); 
messagingService.updateBackPressureOnSend(InetAddressAndPort.getByName("127.0.0.2"), noCallback, ignored); @@ -243,8 +237,8 @@ public void testUpdatesBackPressureOnSendWhenEnabledAndWithSupportedCallback() t public void testUpdatesBackPressureOnReceiveWhenEnabledAndWithSupportedCallback() throws UnknownHostException { MockBackPressureStrategy.MockBackPressureState backPressureState = (MockBackPressureStrategy.MockBackPressureState) messagingService.getBackPressureState(InetAddressAndPort.getByName("127.0.0.2")); - IAsyncCallback bpCallback = new BackPressureCallback(); - IAsyncCallback noCallback = new NoBackPressureCallback(); + RequestCallback bpCallback = new BackPressureCallback(); + RequestCallback noCallback = new NoBackPressureCallback(); boolean timeout = false; DatabaseDescriptor.setBackPressureEnabled(true); @@ -267,8 +261,8 @@ public void testUpdatesBackPressureOnReceiveWhenEnabledAndWithSupportedCallback( public void testUpdatesBackPressureOnTimeoutWhenEnabledAndWithSupportedCallback() throws UnknownHostException { MockBackPressureStrategy.MockBackPressureState backPressureState = (MockBackPressureStrategy.MockBackPressureState) messagingService.getBackPressureState(InetAddressAndPort.getByName("127.0.0.2")); - IAsyncCallback bpCallback = new BackPressureCallback(); - IAsyncCallback noCallback = new NoBackPressureCallback(); + RequestCallback bpCallback = new BackPressureCallback(); + RequestCallback noCallback = new NoBackPressureCallback(); boolean timeout = true; DatabaseDescriptor.setBackPressureEnabled(true); @@ -309,7 +303,7 @@ public void testDoesntApplyBackPressureToBroadcastAddress() throws UnknownHostEx private static void addDCLatency(long sentAt, long nowTime) throws IOException { - MessageIn.deriveConstructionTime(InetAddressAndPort.getLocalHost(), (int) sentAt, nowTime); + MessagingService.instance().metrics.internodeLatencyRecorder(InetAddressAndPort.getLocalHost()).accept(nowTime - sentAt, MILLISECONDS); } public static class 
MockBackPressureStrategy implements BackPressureStrategy @@ -346,7 +340,7 @@ private MockBackPressureState(InetAddressAndPort host) } @Override - public void onMessageSent(MessageOut message) + public void onMessageSent(Message message) { onSend = true; } @@ -377,7 +371,7 @@ public InetAddressAndPort getHost() } } - private static class BackPressureCallback implements IAsyncCallback + private static class BackPressureCallback implements RequestCallback { @Override public boolean supportsBackPressure() @@ -386,19 +380,13 @@ public boolean supportsBackPressure() } @Override - public boolean isLatencyForSnitch() - { - return false; - } - - @Override - public void response(MessageIn msg) + public void onResponse(Message msg) { throw new UnsupportedOperationException("Not supported."); } } - private static class NoBackPressureCallback implements IAsyncCallback + private static class NoBackPressureCallback implements RequestCallback { @Override public boolean supportsBackPressure() @@ -407,13 +395,7 @@ public boolean supportsBackPressure() } @Override - public boolean isLatencyForSnitch() - { - return false; - } - - @Override - public void response(MessageIn msg) + public void onResponse(Message msg) { throw new UnsupportedOperationException("Not supported."); } @@ -433,167 +415,111 @@ public void testFailedInternodeAuth() throws Exception InetAddressAndPort address = InetAddressAndPort.getByName("127.0.0.250"); //Should return null - MessageOut messageOut = new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_ACK); + Message messageOut = Message.out(Verb.ECHO_REQ, NoPayload.noPayload); assertFalse(ms.isConnected(address, messageOut)); //Should tolerate null - ms.convict(address); - ms.sendOneWay(messageOut, address); - } - - @Test - public void testOutboundMessagingConnectionCleansUp() throws Exception - { - MessagingService ms = MessagingService.instance(); - InetAddressAndPort local = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.1", 9876); - InetAddressAndPort 
remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.2", 9876); - - OutboundMessagingPool pool = new OutboundMessagingPool(remote, local, null, new MockBackPressureStrategy(null).newState(remote), ALLOW_NOTHING_AUTHENTICATOR); - ms.channelManagers.put(remote, pool); - pool.sendMessage(new MessageOut(MessagingService.Verb.GOSSIP_DIGEST_ACK), 0); - assertFalse(ms.channelManagers.containsKey(remote)); - } - - @Test - public void reconnectWithNewIp() throws Exception - { - InetAddressAndPort publicIp = InetAddressAndPort.getByName("127.0.0.2"); - InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.3"); - - // reset the preferred IP value, for good test hygene - SystemKeyspace.updatePreferredIP(publicIp, publicIp); - - // create pool/conn with public addr - Assert.assertEquals(publicIp, messagingService.getCurrentEndpoint(publicIp)); - messagingService.reconnectWithNewIp(publicIp, privateIp); - Assert.assertEquals(privateIp, messagingService.getCurrentEndpoint(publicIp)); - - messagingService.destroyConnectionPool(publicIp); - - // recreate the pool/conn, and make sure the preferred ip addr is used - Assert.assertEquals(privateIp, messagingService.getCurrentEndpoint(publicIp)); - } - - @Test - public void testCloseInboundConnections() throws UnknownHostException, InterruptedException - { - try - { - messagingService.listen(); - Assert.assertTrue(messagingService.isListening()); - Assert.assertTrue(messagingService.serverChannels.size() > 0); - for (ServerChannel serverChannel : messagingService.serverChannels) - Assert.assertEquals(0, serverChannel.size()); - - // now, create a connection and make sure it's in a channel group - InetAddressAndPort server = FBUtilities.getBroadcastAddressAndPort(); - OutboundConnectionIdentifier id = OutboundConnectionIdentifier.small(InetAddressAndPort.getByNameOverrideDefaults("127.0.0.2", 0), server); - - CountDownLatch latch = new CountDownLatch(1); - OutboundConnectionParams params = 
OutboundConnectionParams.builder() - .mode(NettyFactory.Mode.MESSAGING) - .sendBufferSize(1 << 10) - .connectionId(id) - .callback(handshakeResult -> latch.countDown()) - .protocolVersion(MessagingService.current_version) - .build(); - Bootstrap bootstrap = NettyFactory.instance.createOutboundBootstrap(params); - Channel channel = bootstrap.connect().awaitUninterruptibly().channel(); - Assert.assertNotNull(channel); - latch.await(1, TimeUnit.SECONDS); // allow the netty pipeline/c* handshake to get set up - - int connectCount = 0; - for (ServerChannel serverChannel : messagingService.serverChannels) - connectCount += serverChannel.size(); - Assert.assertTrue(connectCount > 0); - } - finally - { - // last, shutdown the MS and make sure connections are removed - messagingService.shutdown(true); - for (ServerChannel serverChannel : messagingService.serverChannels) - Assert.assertEquals(0, serverChannel.size()); - messagingService.clearServerChannels(); - } - } + ms.closeOutbound(address); + ms.send(messageOut, address); + } + +// @Test +// public void reconnectWithNewIp() throws Exception +// { +// InetAddressAndPort publicIp = InetAddressAndPort.getByName("127.0.0.2"); +// InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.3"); +// +// // reset the preferred IP value, for good test hygiene +// SystemKeyspace.updatePreferredIP(publicIp, publicIp); +// +// // create pool/conn with public addr +// Assert.assertEquals(publicIp, messagingService.getCurrentEndpoint(publicIp)); +// messagingService.maybeReconnectWithNewIp(publicIp, privateIp).await(1L, TimeUnit.SECONDS); +// Assert.assertEquals(privateIp, messagingService.getCurrentEndpoint(publicIp)); +// +// messagingService.closeOutbound(publicIp); +// +// // recreate the pool/conn, and make sure the preferred ip addr is used +// Assert.assertEquals(privateIp, messagingService.getCurrentEndpoint(publicIp)); +// } @Test - public void listenPlainConnection() + public void listenPlainConnection() throws 
InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = false; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(false); listen(serverEncryptionOptions, false); } @Test - public void listenPlainConnectionWithBroadcastAddr() + public void listenPlainConnectionWithBroadcastAddr() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = false; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(false); listen(serverEncryptionOptions, true); } @Test - public void listenRequiredSecureConnection() + public void listenRequiredSecureConnection() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = false; - serverEncryptionOptions.enable_legacy_ssl_storage_port = false; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(true) + .withOptional(false) + .withLegacySslStoragePort(false); listen(serverEncryptionOptions, false); } @Test - public void listenRequiredSecureConnectionWithBroadcastAddr() + public void listenRequiredSecureConnectionWithBroadcastAddr() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = false; - serverEncryptionOptions.enable_legacy_ssl_storage_port = false; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(true) + .withOptional(false) + .withLegacySslStoragePort(false); listen(serverEncryptionOptions, true); } @Test - public void listenRequiredSecureConnectionWithLegacyPort() + public void 
listenRequiredSecureConnectionWithLegacyPort() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = false; - serverEncryptionOptions.enable_legacy_ssl_storage_port = true; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(true) + .withOptional(false) + .withLegacySslStoragePort(true); listen(serverEncryptionOptions, false); } @Test - public void listenRequiredSecureConnectionWithBroadcastAddrAndLegacyPort() + public void listenRequiredSecureConnectionWithBroadcastAddrAndLegacyPort() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = false; - serverEncryptionOptions.enable_legacy_ssl_storage_port = true; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(true) + .withOptional(false) + .withLegacySslStoragePort(true); listen(serverEncryptionOptions, true); } @Test - public void listenOptionalSecureConnection() + public void listenOptionalSecureConnection() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = true; + ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions() + .withEnabled(true) + .withOptional(true); listen(serverEncryptionOptions, false); } @Test - public void listenOptionalSecureConnectionWithBroadcastAddr() + public void listenOptionalSecureConnectionWithBroadcastAddr() throws InterruptedException { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = true; - serverEncryptionOptions.optional = true; + ServerEncryptionOptions serverEncryptionOptions = new 
ServerEncryptionOptions() + .withEnabled(true) + .withOptional(true); listen(serverEncryptionOptions, true); } - private void listen(ServerEncryptionOptions serverEncryptionOptions, boolean listenOnBroadcastAddr) + private void listen(ServerEncryptionOptions serverEncryptionOptions, boolean listenOnBroadcastAddr) throws InterruptedException { - InetAddress listenAddress = null; + InetAddress listenAddress = FBUtilities.getJustLocalAddress(); if (listenOnBroadcastAddr) { DatabaseDescriptor.setShouldListenOnBroadcastAddress(true); @@ -602,111 +528,93 @@ private void listen(ServerEncryptionOptions serverEncryptionOptions, boolean lis FBUtilities.reset(); } + InboundConnectionSettings settings = new InboundConnectionSettings() + .withEncryption(serverEncryptionOptions); + InboundSockets connections = new InboundSockets(settings); try { - messagingService.listen(serverEncryptionOptions); - Assert.assertTrue(messagingService.isListening()); - int expectedListeningCount = NettyFactory.determineAcceptGroupSize(serverEncryptionOptions); - Assert.assertEquals(expectedListeningCount, messagingService.serverChannels.size()); + connections.open().await(); + Assert.assertTrue(connections.isListening()); - if (!serverEncryptionOptions.enabled) - { - // make sure no channel is using TLS - for (ServerChannel serverChannel : messagingService.serverChannels) - Assert.assertEquals(ServerChannel.SecurityLevel.NONE, serverChannel.getSecurityLevel()); - } - else + Set expect = new HashSet<>(); + expect.add(InetAddressAndPort.getByAddressOverrideDefaults(listenAddress, DatabaseDescriptor.getStoragePort())); + if (settings.encryption.enable_legacy_ssl_storage_port) + expect.add(InetAddressAndPort.getByAddressOverrideDefaults(listenAddress, DatabaseDescriptor.getSSLStoragePort())); + if (listenOnBroadcastAddr) { - final int legacySslPort = DatabaseDescriptor.getSSLStoragePort(); - boolean foundLegacyListenSslAddress = false; - for (ServerChannel serverChannel : 
messagingService.serverChannels) - { - if (serverEncryptionOptions.optional) - Assert.assertEquals(ServerChannel.SecurityLevel.OPTIONAL, serverChannel.getSecurityLevel()); - else - Assert.assertEquals(ServerChannel.SecurityLevel.REQUIRED, serverChannel.getSecurityLevel()); - - if (serverEncryptionOptions.enable_legacy_ssl_storage_port) - { - if (legacySslPort == serverChannel.getAddress().port) - { - foundLegacyListenSslAddress = true; - Assert.assertEquals(ServerChannel.SecurityLevel.REQUIRED, serverChannel.getSecurityLevel()); - } - } - } - - if (serverEncryptionOptions.enable_legacy_ssl_storage_port && !foundLegacyListenSslAddress) - Assert.fail("failed to find legacy ssl listen address"); + expect.add(InetAddressAndPort.getByAddressOverrideDefaults(FBUtilities.getBroadcastAddressAndPort().address, DatabaseDescriptor.getStoragePort())); + if (settings.encryption.enable_legacy_ssl_storage_port) + expect.add(InetAddressAndPort.getByAddressOverrideDefaults(FBUtilities.getBroadcastAddressAndPort().address, DatabaseDescriptor.getSSLStoragePort())); } - // check the optional listen address - if (listenOnBroadcastAddr) + Assert.assertEquals(expect.size(), connections.sockets().size()); + + final int legacySslPort = DatabaseDescriptor.getSSLStoragePort(); + for (InboundSockets.InboundSocket socket : connections.sockets()) { - int expectedCount = (serverEncryptionOptions.enabled && serverEncryptionOptions.enable_legacy_ssl_storage_port) ? 
2 : 1; - int found = 0; - for (ServerChannel serverChannel : messagingService.serverChannels) - { - if (serverChannel.getAddress().address.equals(listenAddress)) - found++; - } - - Assert.assertEquals(expectedCount, found); + Assert.assertEquals(serverEncryptionOptions.enabled, socket.settings.encryption.enabled); + Assert.assertEquals(serverEncryptionOptions.optional, socket.settings.encryption.optional); + if (!serverEncryptionOptions.enabled) + Assert.assertFalse(legacySslPort == socket.settings.bindAddress.port); + if (legacySslPort == socket.settings.bindAddress.port) + Assert.assertFalse(socket.settings.encryption.optional); + Assert.assertTrue(socket.settings.bindAddress.toString(), expect.remove(socket.settings.bindAddress)); } } finally { - messagingService.shutdown(true); - messagingService.clearServerChannels(); - Assert.assertEquals(0, messagingService.serverChannels.size()); + connections.close().await(); + Assert.assertFalse(connections.isListening()); } } - @Test - public void getPreferredRemoteAddrUsesPrivateIp() throws UnknownHostException - { - MessagingService ms = MessagingService.instance(); - InetAddressAndPort local = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.4", 7000); - InetAddressAndPort remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.151", 7000); - InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.6"); - - OutboundMessagingPool pool = new OutboundMessagingPool(privateIp, local, null, - new MockBackPressureStrategy(null).newState(remote), - ALLOW_NOTHING_AUTHENTICATOR); - ms.channelManagers.put(remote, pool); - - Assert.assertEquals(privateIp, ms.getPreferredRemoteAddr(remote)); - } - - @Test - public void getPreferredRemoteAddrUsesPreferredIp() throws UnknownHostException - { - MessagingService ms = MessagingService.instance(); - InetAddressAndPort remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.115", 7000); - - InetAddressAndPort preferredIp = 
InetAddressAndPort.getByName("127.0.0.16"); - SystemKeyspace.updatePreferredIP(remote, preferredIp); - - Assert.assertEquals(preferredIp, ms.getPreferredRemoteAddr(remote)); - } - - @Test - public void getPreferredRemoteAddrUsesPrivateIpOverridesPreferredIp() throws UnknownHostException - { - MessagingService ms = MessagingService.instance(); - InetAddressAndPort local = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.4", 7000); - InetAddressAndPort remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.105", 7000); - InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.6"); - - OutboundMessagingPool pool = new OutboundMessagingPool(privateIp, local, null, - new MockBackPressureStrategy(null).newState(remote), - ALLOW_NOTHING_AUTHENTICATOR); - ms.channelManagers.put(remote, pool); - - InetAddressAndPort preferredIp = InetAddressAndPort.getByName("127.0.0.16"); - SystemKeyspace.updatePreferredIP(remote, preferredIp); - - Assert.assertEquals(privateIp, ms.getPreferredRemoteAddr(remote)); - } +// @Test +// public void getPreferredRemoteAddrUsesPrivateIp() throws UnknownHostException +// { +// MessagingService ms = MessagingService.instance(); +// InetAddressAndPort remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.151", 7000); +// InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.6"); +// +// OutboundConnectionSettings template = new OutboundConnectionSettings(remote) +// .withConnectTo(privateIp) +// .withAuthenticator(ALLOW_NOTHING_AUTHENTICATOR); +// OutboundConnections pool = new OutboundConnections(template, new MockBackPressureStrategy(null).newState(remote)); +// ms.channelManagers.put(remote, pool); +// +// Assert.assertEquals(privateIp, ms.getPreferredRemoteAddr(remote)); +// } +// +// @Test +// public void getPreferredRemoteAddrUsesPreferredIp() throws UnknownHostException +// { +// MessagingService ms = MessagingService.instance(); +// InetAddressAndPort remote = 
InetAddressAndPort.getByNameOverrideDefaults("127.0.0.115", 7000); +// +// InetAddressAndPort preferredIp = InetAddressAndPort.getByName("127.0.0.16"); +// SystemKeyspace.updatePreferredIP(remote, preferredIp); +// +// Assert.assertEquals(preferredIp, ms.getPreferredRemoteAddr(remote)); +// } +// +// @Test +// public void getPreferredRemoteAddrUsesPrivateIpOverridesPreferredIp() throws UnknownHostException +// { +// MessagingService ms = MessagingService.instance(); +// InetAddressAndPort local = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.4", 7000); +// InetAddressAndPort remote = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.105", 7000); +// InetAddressAndPort privateIp = InetAddressAndPort.getByName("127.0.0.6"); +// +// OutboundConnectionSettings template = new OutboundConnectionSettings(remote) +// .withConnectTo(privateIp) +// .withAuthenticator(ALLOW_NOTHING_AUTHENTICATOR); +// +// OutboundConnections pool = new OutboundConnections(template, new MockBackPressureStrategy(null).newState(remote)); +// ms.channelManagers.put(remote, pool); +// +// InetAddressAndPort preferredIp = InetAddressAndPort.getByName("127.0.0.16"); +// SystemKeyspace.updatePreferredIP(remote, preferredIp); +// +// Assert.assertEquals(privateIp, ms.getPreferredRemoteAddr(remote)); +// } } diff --git a/test/unit/org/apache/cassandra/net/MockMessagingService.java b/test/unit/org/apache/cassandra/net/MockMessagingService.java index 79edae8398ad..3749bafba70e 100644 --- a/test/unit/org/apache/cassandra/net/MockMessagingService.java +++ b/test/unit/org/apache/cassandra/net/MockMessagingService.java @@ -26,7 +26,7 @@ * Starting point for mocking {@link MessagingService} interactions. Outgoing messages can be * intercepted by first creating a {@link MatcherResponse} by calling {@link MockMessagingService#when(Matcher)}. 
* Alternatively {@link Matcher}s can be created by using helper methods such as {@link #to(InetAddressAndPort)}, - * {@link #verb(MessagingService.Verb)} or {@link #payload(Predicate)} and may also be + * {@link #verb(Verb)} or {@link #payload(Predicate)} and may also be * nested using {@link MockMessagingService#all(Matcher[])} or {@link MockMessagingService#any(Matcher[])}. * After each test, {@link MockMessagingService#cleanup()} must be called for free listeners registered * in {@link MessagingService}. @@ -47,12 +47,13 @@ public static MatcherResponse when(Matcher matcher) } /** - * Unsubscribes any handlers added by calling {@link MessagingService#addMessageSink(IMessageSink)}. + * Unsubscribes any handlers. * This should be called after each test. */ public static void cleanup() { - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().outboundSink.clear(); + MessagingService.instance().inboundSink.clear(); } /** @@ -92,15 +93,15 @@ public static Matcher to(Predicate predi * Creates a matcher that will indicate if the verb of the outgoing message equals the * provided value. */ - public static Matcher verb(MessagingService.Verb verb) + public static Matcher verb(Verb verb) { - return (in, to) -> in.verb == verb; + return (in, to) -> in.verb() == verb; } /** * Creates a matcher based on the result of the provided predicate called with the outgoing message. */ - public static Matcher message(Predicate> fn) + public static Matcher message(Predicate> fn) { return (msg, to) -> fn.test(msg); } @@ -126,7 +127,7 @@ public static Matcher not(Matcher matcher) */ public static Matcher all(Matcher... matchers) { - return (MessageOut out, InetAddressAndPort to) -> { + return (Message out, InetAddressAndPort to) -> { for (Matcher matcher : matchers) { if (!matcher.matches(out, to)) @@ -141,7 +142,7 @@ public static Matcher all(Matcher... matchers) */ public static Matcher any(Matcher... 
matchers) { - return (MessageOut out, InetAddressAndPort to) -> { + return (Message out, InetAddressAndPort to) -> { for (Matcher matcher : matchers) { if (matcher.matches(out, to)) diff --git a/test/unit/org/apache/cassandra/net/MockMessagingServiceTest.java b/test/unit/org/apache/cassandra/net/MockMessagingServiceTest.java index 8d0f91bf8d55..e4787f74ed38 100644 --- a/test/unit/org/apache/cassandra/net/MockMessagingServiceTest.java +++ b/test/unit/org/apache/cassandra/net/MockMessagingServiceTest.java @@ -17,7 +17,6 @@ */ package org.apache.cassandra.net; -import java.util.Collections; import java.util.concurrent.ExecutionException; import org.junit.Before; @@ -26,14 +25,15 @@ import org.apache.cassandra.SchemaLoader; import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.gms.EchoMessage; import org.apache.cassandra.service.StorageService; import org.apache.cassandra.utils.FBUtilities; +import static org.apache.cassandra.net.Verb.ECHO_REQ; import static org.apache.cassandra.net.MockMessagingService.all; import static org.apache.cassandra.net.MockMessagingService.to; import static org.apache.cassandra.net.MockMessagingService.verb; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; public class MockMessagingServiceTest @@ -54,40 +54,28 @@ public void cleanup() @Test public void testRequestResponse() throws InterruptedException, ExecutionException { - // echo message that we like to mock as incoming reply for outgoing echo message - MessageIn echoMessageIn = MessageIn.create(FBUtilities.getBroadcastAddressAndPort(), - EchoMessage.instance, - Collections.emptyMap(), - MessagingService.Verb.ECHO, - MessagingService.current_version); + // echo message that we like to mock as incoming response for outgoing echo message + Message echoMessage = Message.out(ECHO_REQ, NoPayload.noPayload); MockMessagingSpy spy = MockMessagingService .when( all( 
to(FBUtilities.getBroadcastAddressAndPort()), - verb(MessagingService.Verb.ECHO) + verb(ECHO_REQ) ) ) - .respond(echoMessageIn); + .respond(echoMessage); - MessageOut echoMessageOut = new MessageOut<>(MessagingService.Verb.ECHO, EchoMessage.instance, EchoMessage.serializer); - MessagingService.instance().sendRR(echoMessageOut, FBUtilities.getBroadcastAddressAndPort(), new IAsyncCallback() + Message echoMessageOut = Message.out(ECHO_REQ, NoPayload.noPayload); + MessagingService.instance().sendWithCallback(echoMessageOut, FBUtilities.getBroadcastAddressAndPort(), msg -> { - public void response(MessageIn msg) - { - assertEquals(MessagingService.Verb.ECHO, msg.verb); - assertEquals(echoMessageIn.payload, msg.payload); - } - - public boolean isLatencyForSnitch() - { - return false; - } + assertEquals(ECHO_REQ, msg.verb()); + assertEquals(echoMessage.payload, msg.payload); }); // we must have intercepted the outgoing message at this point - MessageOut msg = spy.captureMessageOut().get(); + Message msg = spy.captureMessageOut().get(); assertEquals(1, spy.messagesIntercepted); - assertTrue(msg == echoMessageOut); + assertSame(echoMessage.payload, msg.payload); // and return a mocked response assertEquals(1, spy.mockedMessageResponses); diff --git a/test/unit/org/apache/cassandra/net/MockMessagingSpy.java b/test/unit/org/apache/cassandra/net/MockMessagingSpy.java index 80bdb39a0f1a..bf4c2267c8e2 100644 --- a/test/unit/org/apache/cassandra/net/MockMessagingSpy.java +++ b/test/unit/org/apache/cassandra/net/MockMessagingSpy.java @@ -43,25 +43,25 @@ public class MockMessagingSpy public int messagesIntercepted = 0; public int mockedMessageResponses = 0; - private final BlockingQueue> interceptedMessages = new LinkedBlockingQueue<>(); - private final BlockingQueue> deliveredResponses = new LinkedBlockingQueue<>(); + private final BlockingQueue> interceptedMessages = new LinkedBlockingQueue<>(); + private final BlockingQueue> deliveredResponses = new LinkedBlockingQueue<>(); 
private static final Executor executor = Executors.newSingleThreadExecutor(); /** * Returns a future with the first mocked incoming message that has been created and delivered. */ - public ListenableFuture> captureMockedMessageIn() + public ListenableFuture> captureMockedMessage() { - return Futures.transform(captureMockedMessageInN(1), (List> result) -> result.isEmpty() ? null : result.get(0)); + return Futures.transform(captureMockedMessageN(1), (List> result) -> result.isEmpty() ? null : result.get(0)); } /** * Returns a future with the specified number mocked incoming messages that have been created and delivered. */ - public ListenableFuture>> captureMockedMessageInN(int noOfMessages) + public ListenableFuture>> captureMockedMessageN(int noOfMessages) { - CapturedResultsFuture> ret = new CapturedResultsFuture<>(noOfMessages, deliveredResponses); + CapturedResultsFuture> ret = new CapturedResultsFuture<>(noOfMessages, deliveredResponses); executor.execute(ret); return ret; } @@ -69,17 +69,17 @@ public ListenableFuture>> captureMockedMessageInN(int noOfMess /** * Returns a future that will indicate if a mocked incoming message has been created and delivered. */ - public ListenableFuture expectMockedMessageIn() + public ListenableFuture expectMockedMessage() { - return expectMockedMessageIn(1); + return expectMockedMessage(1); } /** * Returns a future that will indicate if the specified number of mocked incoming message have been created and delivered. 
*/ - public ListenableFuture expectMockedMessageIn(int noOfMessages) + public ListenableFuture expectMockedMessage(int noOfMessages) { - ResultsCompletionFuture> ret = new ResultsCompletionFuture<>(noOfMessages, deliveredResponses); + ResultsCompletionFuture> ret = new ResultsCompletionFuture<>(noOfMessages, deliveredResponses); executor.execute(ret); return ret; } @@ -87,17 +87,17 @@ public ListenableFuture expectMockedMessageIn(int noOfMessages) /** * Returns a future with the first intercepted outbound message that would have been send. */ - public ListenableFuture> captureMessageOut() + public ListenableFuture> captureMessageOut() { - return Futures.transform(captureMessageOut(1), (List> result) -> result.isEmpty() ? null : result.get(0)); + return Futures.transform(captureMessageOut(1), (List> result) -> result.isEmpty() ? null : result.get(0)); } /** * Returns a future with the specified number of intercepted outbound messages that would have been send. */ - public ListenableFuture>> captureMessageOut(int noOfMessages) + public ListenableFuture>> captureMessageOut(int noOfMessages) { - CapturedResultsFuture> ret = new CapturedResultsFuture<>(noOfMessages, interceptedMessages); + CapturedResultsFuture> ret = new CapturedResultsFuture<>(noOfMessages, interceptedMessages); executor.execute(ret); return ret; } @@ -115,7 +115,7 @@ public ListenableFuture interceptMessageOut() */ public ListenableFuture interceptMessageOut(int noOfMessages) { - ResultsCompletionFuture> ret = new ResultsCompletionFuture<>(noOfMessages, interceptedMessages); + ResultsCompletionFuture> ret = new ResultsCompletionFuture<>(noOfMessages, interceptedMessages); executor.execute(ret); return ret; } @@ -125,19 +125,19 @@ public ListenableFuture interceptMessageOut(int noOfMessages) */ public ListenableFuture interceptNoMsg(long time, TimeUnit unit) { - ResultAbsenceFuture> ret = new ResultAbsenceFuture<>(interceptedMessages, time, unit); + ResultAbsenceFuture> ret = new 
ResultAbsenceFuture<>(interceptedMessages, time, unit); executor.execute(ret); return ret; } - void matchingMessage(MessageOut message) + void matchingMessage(Message message) { messagesIntercepted++; logger.trace("Received matching message: {}", message); interceptedMessages.add(message); } - void matchingResponse(MessageIn response) + void matchingResponse(Message response) { mockedMessageResponses++; logger.trace("Responding to intercepted message: {}", response); diff --git a/test/unit/org/apache/cassandra/net/OutboundConnectionSettingsTest.java b/test/unit/org/apache/cassandra/net/OutboundConnectionSettingsTest.java new file mode 100644 index 000000000000..7cf78a77a91d --- /dev/null +++ b/test/unit/org/apache/cassandra/net/OutboundConnectionSettingsTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.locator.AbstractEndpointSnitch; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.locator.Replica; + +import static org.apache.cassandra.config.DatabaseDescriptor.getEndpointSnitch; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.ConnectionType.*; +import static org.apache.cassandra.net.OutboundConnectionsTest.LOCAL_ADDR; +import static org.apache.cassandra.net.OutboundConnectionsTest.REMOTE_ADDR; + +public class OutboundConnectionSettingsTest +{ + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test (expected = IllegalArgumentException.class) + public void build_SmallSendSize() + { + test(settings -> settings.withSocketSendBufferSizeInBytes(999)); + } + + @Test (expected = IllegalArgumentException.class) + public void build_SendSizeLessThanZero() + { + test(settings -> settings.withSocketSendBufferSizeInBytes(-1)); + } + + @Test (expected = IllegalArgumentException.class) + public void build_TcpConnectTimeoutLessThanZero() + { + test(settings -> settings.withTcpConnectTimeoutInMS(-1)); + } + + @Test(expected = IllegalArgumentException.class) + public void build_TcpUserTimeoutLessThanZero() + { + test(settings -> settings.withTcpUserTimeoutInMS(-1)); + } + + @Test + public void build_TcpUserTimeoutEqualsZero() + { + test(settings -> settings.withTcpUserTimeoutInMS(0)); + } + + private static void test(Function f) + { + f.apply(new OutboundConnectionSettings(LOCAL_ADDR)).withDefaults(ConnectionCategory.MESSAGING); + } + + private static class TestSnitch extends 
AbstractEndpointSnitch + { + private final Map nodeToDc = new HashMap<>(); + + void add(InetAddressAndPort node, String dc) + { + nodeToDc.put(node, dc); + } + + public String getRack(InetAddressAndPort endpoint) + { + return null; + } + + public String getDatacenter(InetAddressAndPort endpoint) + { + return nodeToDc.get(endpoint); + } + + public int compareEndpoints(InetAddressAndPort target, Replica a1, Replica a2) + { + return 0; + } + } + + @Test + public void shouldCompressConnection_None() + { + DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.none); + Assert.assertFalse(OutboundConnectionSettings.shouldCompressConnection(getEndpointSnitch(), LOCAL_ADDR, REMOTE_ADDR)); + } + + @Test + public void shouldCompressConnection_DifferentDc() + { + TestSnitch snitch = new TestSnitch(); + snitch.add(LOCAL_ADDR, "dc1"); + snitch.add(REMOTE_ADDR, "dc2"); + DatabaseDescriptor.setEndpointSnitch(snitch); + DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.dc); + Assert.assertTrue(OutboundConnectionSettings.shouldCompressConnection(getEndpointSnitch(), LOCAL_ADDR, REMOTE_ADDR)); + } + + @Test + public void shouldCompressConnection_All() + { + DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.all); + Assert.assertTrue(OutboundConnectionSettings.shouldCompressConnection(getEndpointSnitch(), LOCAL_ADDR, REMOTE_ADDR)); + } + + @Test + public void shouldCompressConnection_SameDc() + { + TestSnitch snitch = new TestSnitch(); + snitch.add(LOCAL_ADDR, "dc1"); + snitch.add(REMOTE_ADDR, "dc1"); + DatabaseDescriptor.setEndpointSnitch(snitch); + DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.dc); + Assert.assertFalse(OutboundConnectionSettings.shouldCompressConnection(getEndpointSnitch(), LOCAL_ADDR, REMOTE_ADDR)); + } + +} diff --git a/test/unit/org/apache/cassandra/net/OutboundConnectionsTest.java b/test/unit/org/apache/cassandra/net/OutboundConnectionsTest.java new file mode 100644 index 
000000000000..20180fb9b4e6 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/OutboundConnectionsTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.gms.GossipDigestSyn; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.BackPressureState; +import org.apache.cassandra.net.ConnectionType; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.OutboundConnectionSettings; +import org.apache.cassandra.net.OutboundConnections; +import 
org.apache.cassandra.net.PingRequest; +import org.apache.cassandra.net.Verb; + +public class OutboundConnectionsTest +{ + static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9476); + static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9476); + private static final InetAddressAndPort RECONNECT_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.3"), 9476); + private static final List INTERNODE_MESSAGING_CONN_TYPES = ImmutableList.of(ConnectionType.URGENT_MESSAGES, ConnectionType.LARGE_MESSAGES, ConnectionType.SMALL_MESSAGES); + + private OutboundConnections connections; + + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setup() + { + BackPressureState backPressureState = DatabaseDescriptor.getBackPressureStrategy().newState(REMOTE_ADDR); + connections = OutboundConnections.unsafeCreate(new OutboundConnectionSettings(REMOTE_ADDR), backPressureState); + } + + @After + public void tearDown() throws ExecutionException, InterruptedException, TimeoutException + { + if (connections != null) + connections.close(false).get(10L, TimeUnit.SECONDS); + } + + @Test + public void getConnection_Gossip() + { + GossipDigestSyn syn = new GossipDigestSyn("cluster", "partitioner", new ArrayList<>(0)); + Message message = Message.out(Verb.GOSSIP_DIGEST_SYN, syn); + Assert.assertEquals(ConnectionType.URGENT_MESSAGES, connections.connectionFor(message).type()); + } + + @Test + public void getConnection_SmallMessage() + { + Message message = Message.out(Verb.PING_REQ, PingRequest.forSmall); + Assert.assertEquals(ConnectionType.SMALL_MESSAGES, connections.connectionFor(message).type()); + } + + @Test + public void getConnection_LargeMessage() throws NoSuchFieldException, IllegalAccessException + { + // just need a serializer 
to report a size, as fake as it may be + IVersionedSerializer serializer = new IVersionedSerializer() + { + public void serialize(Object o, DataOutputPlus out, int version) + { + + } + + public Object deserialize(DataInputPlus in, int version) + { + return null; + } + + public long serializedSize(Object o, int version) + { + return OutboundConnections.LARGE_MESSAGE_THRESHOLD + 1; + } + }; + Verb._TEST_2.unsafeSetSerializer(() -> serializer); + Message message = Message.out(Verb._TEST_2, "payload"); + Assert.assertEquals(ConnectionType.LARGE_MESSAGES, connections.connectionFor(message).type()); + } + + @Test + public void close_SoftClose() throws ExecutionException, InterruptedException, TimeoutException + { + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + Assert.assertFalse(connections.connectionFor(type).isClosed()); + connections.close(true).get(10L, TimeUnit.SECONDS); + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + Assert.assertTrue(connections.connectionFor(type).isClosed()); + } + + @Test + public void close_NotSoftClose() throws ExecutionException, InterruptedException, TimeoutException + { + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + Assert.assertFalse(connections.connectionFor(type).isClosed()); + connections.close(false).get(10L, TimeUnit.SECONDS); + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + Assert.assertTrue(connections.connectionFor(type).isClosed()); + } + + @Test + public void reconnectWithNewIp() throws InterruptedException + { + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + { + Assert.assertEquals(REMOTE_ADDR, connections.connectionFor(type).settings().connectTo); + } + + connections.reconnectWithNewIp(RECONNECT_ADDR).await(); + + for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) + { + Assert.assertEquals(RECONNECT_ADDR, connections.connectionFor(type).settings().connectTo); + } + } + +// @Test +// public void timeoutCounter() +// { +// long originalValue = 
connections.getTimeouts(); +// connections.incrementTimeout(); +// Assert.assertEquals(originalValue + 1, connections.getTimeouts()); +// } +} diff --git a/test/unit/org/apache/cassandra/net/OutboundMessageQueueTest.java b/test/unit/org/apache/cassandra/net/OutboundMessageQueueTest.java new file mode 100644 index 000000000000..db571ac2e7bf --- /dev/null +++ b/test/unit/org/apache/cassandra/net/OutboundMessageQueueTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.concurrent.CountDownLatch; + +import com.google.common.util.concurrent.Uninterruptibles; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.OutboundMessageQueue; +import org.apache.cassandra.net.Verb; + +import static org.apache.cassandra.net.NoPayload.noPayload; + +// TODO: incomplete +public class OutboundMessageQueueTest +{ + + @BeforeClass + public static void init() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test + public void testRemove() throws InterruptedException + { + final Message m1 = Message.out(Verb._TEST_1, noPayload); + final Message m2 = Message.out(Verb._TEST_1, noPayload); + final Message m3 = Message.out(Verb._TEST_1, noPayload); + + final OutboundMessageQueue queue = new OutboundMessageQueue(message -> true); + queue.add(m1); + queue.add(m2); + queue.add(m3); + + Assert.assertTrue(queue.remove(m1)); + Assert.assertFalse(queue.remove(m1)); + + CountDownLatch locked = new CountDownLatch(1); + CountDownLatch lockUntil = new CountDownLatch(1); + new Thread(() -> { + try (OutboundMessageQueue.WithLock lock = queue.lockOrCallback(0, () -> {})) + { + locked.countDown(); + Uninterruptibles.awaitUninterruptibly(lockUntil); + } + }).start(); + Uninterruptibles.awaitUninterruptibly(locked); + + CountDownLatch start = new CountDownLatch(2); + CountDownLatch finish = new CountDownLatch(2); + new Thread(() -> { + start.countDown(); + Assert.assertTrue(queue.remove(m2)); + finish.countDown(); + }).start(); + new Thread(() -> { + start.countDown(); + Assert.assertTrue(queue.remove(m3)); + finish.countDown(); + }).start(); + Uninterruptibles.awaitUninterruptibly(start); + lockUntil.countDown(); + Uninterruptibles.awaitUninterruptibly(finish); + + try (OutboundMessageQueue.WithLock lock = queue.lockOrCallback(0, () -> {})) + 
{ + Assert.assertNull(lock.peek()); + } + } + +} diff --git a/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java new file mode 100644 index 000000000000..270a910720af --- /dev/null +++ b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.function.ToLongFunction; + +import com.google.common.util.concurrent.Uninterruptibles; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.IVersionedAsymmetricSerializer; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.proxy.InboundProxyHandler; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.Pair; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.apache.cassandra.net.ConnectionTest.SETTINGS; +import static org.apache.cassandra.net.OutboundConnectionSettings.Framing.CRC; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; + +public class ProxyHandlerConnectionsTest +{ + private static final SocketFactory factory = new SocketFactory(); + + private final Map>> serializers = new HashMap<>(); + private final Map>> handlers = new HashMap<>(); + private final Map> timeouts = new HashMap<>(); + + private void unsafeSetSerializer(Verb verb, Supplier> supplier) throws Throwable + { + serializers.putIfAbsent(verb, verb.unsafeSetSerializer(supplier)); + } + + protected void unsafeSetHandler(Verb verb, Supplier> supplier) throws Throwable + { 
+ handlers.putIfAbsent(verb, verb.unsafeSetHandler(supplier)); + } + + private void unsafeSetExpiration(Verb verb, ToLongFunction expiration) throws Throwable + { + timeouts.putIfAbsent(verb, verb.unsafeSetExpiration(expiration)); + } + + @BeforeClass + public static void startup() + { + DatabaseDescriptor.daemonInitialization(); + } + + @After + public void cleanup() throws Throwable + { + for (Map.Entry>> e : serializers.entrySet()) + e.getKey().unsafeSetSerializer(e.getValue()); + serializers.clear(); + for (Map.Entry>> e : handlers.entrySet()) + e.getKey().unsafeSetHandler(e.getValue()); + handlers.clear(); + for (Map.Entry> e : timeouts.entrySet()) + e.getKey().unsafeSetExpiration(e.getValue()); + timeouts.clear(); + } + + @Test + public void testExpireInbound() throws Throwable + { + DatabaseDescriptor.setCrossNodeTimeout(true); + testOneManual((settings, inbound, outbound, endpoint, handler) -> { + unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new); + + CountDownLatch connectionLatch = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, () -> v -> { + connectionLatch.countDown(); + }); + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + connectionLatch.await(10, SECONDS); + Assert.assertEquals(0, connectionLatch.getCount()); + + // Slow things down + unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS)); + handler.withLatency(100, MILLISECONDS); + + unsafeSetHandler(Verb._TEST_1, () -> v -> { + throw new RuntimeException("Should have not been triggered " + v); + }); + int expireMessages = 10; + for (int i = 0; i < expireMessages; i++) + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + + InboundMessageHandlers handlers = MessagingService.instance().getInbound(endpoint); + waitForCondition(() -> handlers.expiredCount() == expireMessages); + Assert.assertEquals(expireMessages, handlers.expiredCount()); + }); + } + + @Test + public void testExpireSome() throws Throwable + { + DatabaseDescriptor.setCrossNodeTimeout(true); + 
testOneManual((settings, inbound, outbound, endpoint, handler) -> { + unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new); + connect(outbound); + + AtomicInteger counter = new AtomicInteger(); + unsafeSetHandler(Verb._TEST_1, () -> v -> { + counter.incrementAndGet(); + }); + + int expireMessages = 10; + for (int i = 0; i < expireMessages; i++) + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + waitForCondition(() -> counter.get() == 10); + + unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS)); + handler.withLatency(100, MILLISECONDS); + + InboundMessageHandlers handlers = MessagingService.instance().getInbound(endpoint); + for (int i = 0; i < expireMessages; i++) + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + waitForCondition(() -> handlers.expiredCount() == 10); + + handler.withLatency(2, MILLISECONDS); + + for (int i = 0; i < expireMessages; i++) + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + waitForCondition(() -> counter.get() == 20); + }); + } + + @Test + public void testExpireSomeFromBatch() throws Throwable + { + DatabaseDescriptor.setCrossNodeTimeout(true); + testManual((settings, inbound, outbound, endpoint, handler) -> { + unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new); + connect(outbound); + + Message msg = Message.out(Verb._TEST_1, 1L); + int messageSize = msg.serializedSize(MessagingService.current_version); + DatabaseDescriptor.setInternodeMaxMessageSizeInBytes(messageSize * 40); + + AtomicInteger counter = new AtomicInteger(); + unsafeSetHandler(Verb._TEST_1, () -> v -> { + counter.incrementAndGet(); + }); + + unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(200, MILLISECONDS)); + handler.withLatency(100, MILLISECONDS); + + int expireMessages = 20; + long nanoTime = approxTime.now(); + CountDownLatch enqueueDone = new CountDownLatch(1); + outbound.unsafeRunOnDelivery(() -> Uninterruptibles.awaitUninterruptibly(enqueueDone, 10, SECONDS)); + for (int i = 0; i < expireMessages; i++) + { 
+ boolean expire = i % 2 == 0; + Message.Builder builder = Message.builder(Verb._TEST_1, 1L); + + if (settings.right.acceptVersions == ConnectionTest.legacy) + { + // backdate messages; leave 50 milliseconds to leave outbound path + builder.withCreatedAt(nanoTime - (expire ? 0 : MILLISECONDS.toNanos(150))); + } + else + { + // Give messages 50 milliseconds to leave outbound path + builder.withCreatedAt(nanoTime) + .withExpiresAt(nanoTime + (expire ? MILLISECONDS.toNanos(50) : MILLISECONDS.toNanos(1000))); + } + outbound.enqueue(builder.build()); + } + enqueueDone.countDown(); + + InboundMessageHandlers handlers = MessagingService.instance().getInbound(endpoint); + waitForCondition(() -> handlers.expiredCount() == 10 && counter.get() == 10, + () -> String.format("Expired: %d, Arrived: %d", handlers.expiredCount(), counter.get())); + }); + } + + @Test + public void suddenDisconnect() throws Throwable + { + testManual((settings, inbound, outbound, endpoint, handler) -> { + handler.onDisconnect(() -> handler.reset()); + + unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new); + connect(outbound); + + CountDownLatch closeLatch = new CountDownLatch(1); + handler.withCloseAfterRead(closeLatch::countDown); + AtomicInteger counter = new AtomicInteger(); + unsafeSetHandler(Verb._TEST_1, () -> v -> counter.incrementAndGet()); + + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + waitForCondition(() -> !outbound.isConnected()); + + connect(outbound); + Assert.assertTrue(outbound.isConnected()); + Assert.assertEquals(0, counter.get()); + }); + } + + @Test + public void testCorruptionOnHandshake() throws Throwable + { + testManual((settings, inbound, outbound, endpoint, handler) -> { + unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new); + // Invalid CRC + handler.withPayloadTransform(msg -> { + ByteBuf bb = (ByteBuf) msg; + bb.setByte(bb.readableBytes() / 2, 0xffff); + return msg; + }); + tryConnect(outbound, 1, SECONDS, false); + 
Assert.assertTrue(!outbound.isConnected()); + + // Invalid protocol magic + handler.withPayloadTransform(msg -> { + ByteBuf bb = (ByteBuf) msg; + bb.setByte(0, 0xffff); + return msg; + }); + tryConnect(outbound, 1, SECONDS, false); + Assert.assertTrue(!outbound.isConnected()); + if (settings.right.framing == CRC) + { + Assert.assertEquals(2, outbound.connectionAttempts()); + Assert.assertEquals(0, outbound.successfulConnections()); + } + }); + } + + private static void waitForCondition(Supplier cond) throws Throwable + { + CompletableFuture.runAsync(() -> { + while (!cond.get()) {} + }).get(10, SECONDS); + } + + private static void waitForCondition(Supplier cond, Supplier s) throws Throwable + { + try + { + CompletableFuture.runAsync(() -> { + while (!cond.get()) {} + }).get(10, SECONDS); + } + catch (TimeoutException e) + { + throw new AssertionError(s.get()); + } + } + + private static class FakePayloadSerializer implements IVersionedSerializer + { + private final int size; + private FakePayloadSerializer() + { + this(1); + } + + // Takes long and repeats it size times + private FakePayloadSerializer(int size) + { + this.size = size; + } + + public void serialize(Long i, DataOutputPlus out, int version) throws IOException + { + for (int j = 0; j < size; j++) + { + out.writeLong(i); + } + } + + public Long deserialize(DataInputPlus in, int version) throws IOException + { + long l = in.readLong(); + for (int i = 0; i < size - 1; i++) + { + if (in.readLong() != l) + throw new AssertionError(); + } + + return l; + } + + public long serializedSize(Long t, int version) + { + return Long.BYTES * size; + } + } + interface ManualSendTest + { + void accept(Pair settings, InboundSockets inbound, OutboundConnection outbound, InetAddressAndPort endpoint, InboundProxyHandler.Controller handler) throws Throwable; + } + + private void testManual(ManualSendTest test) throws Throwable + { + for (ConnectionTest.Settings s: SETTINGS) + { + doTestManual(s, test); + cleanup(); + } + } 
+ + private void testOneManual(ManualSendTest test) throws Throwable + { + testOneManual(test, 1); + } + + private void testOneManual(ManualSendTest test, int i) throws Throwable + { + ConnectionTest.Settings s = SETTINGS.get(i); + doTestManual(s, test); + cleanup(); + } + + private void doTestManual(ConnectionTest.Settings settings, ManualSendTest test) throws Throwable + { + InetAddressAndPort endpoint = FBUtilities.getBroadcastAddressAndPort(); + + InboundConnectionSettings inboundSettings = settings.inbound.apply(new InboundConnectionSettings()) + .withBindAddress(endpoint) + .withSocketFactory(factory); + + InboundSockets inbound = new InboundSockets(Collections.singletonList(inboundSettings)); + + OutboundConnectionSettings outboundSettings = settings.outbound.apply(new OutboundConnectionSettings(endpoint)) + .withConnectTo(endpoint) + .withDefaultReserveLimits() + .withSocketFactory(factory); + + ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(outboundSettings.applicationSendQueueReserveEndpointCapacityInBytes), outboundSettings.applicationSendQueueReserveGlobalCapacityInBytes); + OutboundConnection outbound = new OutboundConnection(settings.type, outboundSettings, reserveCapacityInBytes); + try + { + InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller(); + inbound.open(pipeline -> { + InboundProxyHandler handler = new InboundProxyHandler(controller); + pipeline.addLast(handler); + }).sync(); + test.accept(Pair.create(inboundSettings, outboundSettings), inbound, outbound, endpoint, controller); + } + finally + { + outbound.close(false); + inbound.close().get(30L, SECONDS); + outbound.close(false).get(30L, SECONDS); + MessagingService.instance().messageHandlers.clear(); + } + } + + private void connect(OutboundConnection outbound) throws Throwable + { + tryConnect(outbound, 10, SECONDS, true); + } + + private void tryConnect(OutboundConnection outbound, 
long timeout, TimeUnit timeUnit, boolean throwOnFailure) throws Throwable + { + CountDownLatch connectionLatch = new CountDownLatch(1); + unsafeSetHandler(Verb._TEST_1, () -> v -> { + connectionLatch.countDown(); + }); + outbound.enqueue(Message.out(Verb._TEST_1, 1L)); + connectionLatch.await(timeout, timeUnit); + if (throwOnFailure) + Assert.assertEquals(0, connectionLatch.getCount()); + } +} diff --git a/test/unit/org/apache/cassandra/net/PrunableArrayQueueTest.java b/test/unit/org/apache/cassandra/net/PrunableArrayQueueTest.java new file mode 100644 index 000000000000..c4fd55a8aa27 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/PrunableArrayQueueTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import org.junit.Test; + +import org.apache.cassandra.net.PrunableArrayQueue; + +import static org.junit.Assert.*; + +public class PrunableArrayQueueTest +{ + private final PrunableArrayQueue queue = new PrunableArrayQueue<>(8); + + @Test + public void testIsEmptyWhenEmpty() + { + assertTrue(queue.isEmpty()); + } + + @Test + public void testIsEmptyWhenNotEmpty() + { + queue.offer(0); + assertFalse(queue.isEmpty()); + } + + @Test + public void testEmptyPeek() + { + assertNull(queue.peek()); + } + + @Test + public void testNonEmptyPeek() + { + queue.offer(0); + assertEquals((Integer) 0, queue.peek()); + } + + @Test + public void testEmptyPoll() + { + assertNull(queue.poll()); + } + + @Test + public void testNonEmptyPoll() + { + queue.offer(0); + assertEquals((Integer) 0, queue.poll()); + } + + @Test + public void testTransfersInCorrectOrder() + { + for (int i = 0; i < 1024; i++) + queue.offer(i); + + for (int i = 0; i < 1024; i++) + assertEquals((Integer) i, queue.poll()); + + assertTrue(queue.isEmpty()); + } + + @Test + public void testTransfersInCorrectOrderWhenInterleaved() + { + for (int i = 0; i < 1024; i++) + { + queue.offer(i); + assertEquals((Integer) i, queue.poll()); + } + + assertTrue(queue.isEmpty()); + } + + @Test + public void testPrune() + { + for (int i = 0; i < 1024; i++) + queue.offer(i); + + class Pruner implements PrunableArrayQueue.Pruner + { + private int pruned, kept; + + public boolean shouldPrune(Integer val) + { + return val % 2 == 0; + } + + public void onPruned(Integer val) + { + pruned++; + } + + public void onKept(Integer val) + { + kept++; + } + } + + Pruner pruner = new Pruner(); + assertEquals(512, queue.prune(pruner)); + + assertEquals(512, pruner.kept); + assertEquals(512, pruner.pruned); + assertEquals(512, queue.size()); + + for (int i = 1; i < 1024; i += 2) + assertEquals((Integer) i, queue.poll()); + assertTrue(queue.isEmpty()); + } +} \ No newline at end of file diff --git 
a/test/unit/org/apache/cassandra/net/ResourceLimitsTest.java b/test/unit/org/apache/cassandra/net/ResourceLimitsTest.java new file mode 100644 index 000000000000..734d69afe776 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/ResourceLimitsTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.LongFunction; + +import org.junit.Test; + +import org.apache.cassandra.net.ResourceLimits.*; + +import static org.junit.Assert.*; + +public class ResourceLimitsTest +{ + @Test + public void testAllocatesWithinLimits() + { + testAllocatesWithinLimits(Basic::new); + testAllocatesWithinLimits(Concurrent::new); + } + + private void testAllocatesWithinLimits(LongFunction supplier) + { + Limit limit = supplier.apply(100); + + assertEquals(100, limit.limit()); + assertEquals(0, limit.using()); + assertEquals(100, limit.remaining()); + + assertTrue(limit.tryAllocate(10)); + assertEquals(10, limit.using()); + assertEquals(90, limit.remaining()); + + assertTrue(limit.tryAllocate(30)); + assertEquals(40, limit.using()); + assertEquals(60, limit.remaining()); + + assertTrue(limit.tryAllocate(60)); + assertEquals(100, limit.using()); + assertEquals(0, limit.remaining()); + } + + @Test + public void testFailsToAllocateOverCapacity() + { + testFailsToAllocateOverCapacity(Basic::new); + testFailsToAllocateOverCapacity(Concurrent::new); + } + + private void testFailsToAllocateOverCapacity(LongFunction supplier) + { + Limit limit = supplier.apply(100); + + assertEquals(100, limit.limit()); + assertEquals(0, limit.using()); + assertEquals(100, limit.remaining()); + + assertTrue(limit.tryAllocate(10)); + assertEquals(10, limit.using()); + assertEquals(90, limit.remaining()); + + assertFalse(limit.tryAllocate(91)); + assertEquals(10, limit.using()); + assertEquals(90, limit.remaining()); + } + + @Test + public void testRelease() + { + testRelease(Basic::new); + testRelease(Concurrent::new); + } + + private void testRelease(LongFunction supplier) + { + Limit limit = supplier.apply(100); + + assertEquals(100, limit.limit()); + assertEquals(0, limit.using()); + 
assertEquals(100, limit.remaining()); + + assertTrue(limit.tryAllocate(10)); + assertTrue(limit.tryAllocate(30)); + assertTrue(limit.tryAllocate(60)); + assertEquals(100, limit.using()); + assertEquals(0, limit.remaining()); + + limit.release(10); + assertEquals(90, limit.using()); + assertEquals(10, limit.remaining()); + + limit.release(30); + assertEquals(60, limit.using()); + assertEquals(40, limit.remaining()); + + limit.release(60); + assertEquals(0, limit.using()); + assertEquals(100, limit.remaining()); + } + + @Test + public void testConcurrentLimit() throws Exception + { + int numThreads = 4; + int numPermitsPerThread = 1_000_000; + int numPermits = numThreads * numPermitsPerThread; + + CountDownLatch latch = new CountDownLatch(numThreads); + Limit limit = new Concurrent(numPermits); + + class Worker implements Runnable + { + public void run() + { + for (int i = 0; i < numPermitsPerThread; i += 10) + assertTrue(limit.tryAllocate(10)); + + for (int i = 0; i < numPermitsPerThread; i += 10) + limit.release(10); + + latch.countDown(); + } + } + + Executor executor = Executors.newFixedThreadPool(numThreads); + for (int i = 0; i < numThreads; i++) + executor.execute(new Worker()); + latch.await(10, TimeUnit.SECONDS); + + assertEquals(0, limit.using()); + assertEquals(numPermits, limit.remaining()); + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/SocketUtils.java b/test/unit/org/apache/cassandra/net/SocketUtils.java new file mode 100644 index 000000000000..a0a149029b24 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/SocketUtils.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.net.ServerSocket; + +import com.google.common.base.Throwables; + +public class SocketUtils +{ + public static synchronized int findAvailablePort() throws RuntimeException + { + ServerSocket ss = null; + try + { + // let the system pick an ephemeral port + ss = new ServerSocket(0); + ss.setReuseAddress(true); + return ss.getLocalPort(); + } + catch (IOException e) + { + throw Throwables.propagate(e); + } + finally + { + if (ss != null) + { + try + { + ss.close(); + } + catch (IOException e) + { + Throwables.propagate(e); + } + } + } + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/StartupClusterConnectivityCheckerTest.java b/test/unit/org/apache/cassandra/net/StartupClusterConnectivityCheckerTest.java index af72456ec444..0785f277daba 100644 --- a/test/unit/org/apache/cassandra/net/StartupClusterConnectivityCheckerTest.java +++ b/test/unit/org/apache/cassandra/net/StartupClusterConnectivityCheckerTest.java @@ -19,11 +19,11 @@ package org.apache.cassandra.net; import java.net.UnknownHostException; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.function.BiPredicate; import org.junit.After; import org.junit.Assert; @@ -38,8 +38,6 @@ import 
org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.utils.FBUtilities; -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; - public class StartupClusterConnectivityCheckerTest { private StartupClusterConnectivityChecker localQuorumConnectivityChecker; @@ -108,32 +106,30 @@ public void setUp() throws UnknownHostException @After public void tearDown() { - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().outboundSink.clear(); } @Test public void execute_HappyPath() { Sink sink = new Sink(true, true, peers); - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); Assert.assertTrue(localQuorumConnectivityChecker.execute(peers, this::getDatacenter)); - Assert.assertTrue(checkAllConnectionTypesSeen(sink)); } @Test public void execute_NotAlive() { Sink sink = new Sink(false, true, peers); - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); Assert.assertFalse(localQuorumConnectivityChecker.execute(peers, this::getDatacenter)); - Assert.assertTrue(checkAllConnectionTypesSeen(sink)); } @Test public void execute_NoConnectionsAcks() { Sink sink = new Sink(true, false, peers); - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); Assert.assertFalse(localQuorumConnectivityChecker.execute(peers, this::getDatacenter)); } @@ -143,12 +139,12 @@ public void execute_LocalQuorum() // local peer plus 3 peers from same dc shouldn't pass (4/6) Set available = new HashSet<>(); copyCount(peersAMinusLocal, available, NUM_PER_DC - 3); - checkAvailable(localQuorumConnectivityChecker, available, false, true); + checkAvailable(localQuorumConnectivityChecker, available, false); // local peer plus 4 peers from same dc should pass (5/6) available.clear(); copyCount(peersAMinusLocal, available, NUM_PER_DC - 2); - 
checkAvailable(localQuorumConnectivityChecker, available, true, true); + checkAvailable(localQuorumConnectivityChecker, available, true); } @Test @@ -159,56 +155,45 @@ public void execute_GlobalQuorum() copyCount(peersAMinusLocal, available, NUM_PER_DC - 2); copyCount(peersB, available, NUM_PER_DC - 2); copyCount(peersC, available, NUM_PER_DC - 1); - checkAvailable(globalQuorumConnectivityChecker, available, false, true); + checkAvailable(globalQuorumConnectivityChecker, available, false); // All three datacenters should be able to have a single node down available.clear(); copyCount(peersAMinusLocal, available, NUM_PER_DC - 2); copyCount(peersB, available, NUM_PER_DC - 1); copyCount(peersC, available, NUM_PER_DC - 1); - checkAvailable(globalQuorumConnectivityChecker, available, true, true); + checkAvailable(globalQuorumConnectivityChecker, available, true); // Everything being up should work of course available.clear(); copyCount(peersAMinusLocal, available, NUM_PER_DC - 1); copyCount(peersB, available, NUM_PER_DC); copyCount(peersC, available, NUM_PER_DC); - checkAvailable(globalQuorumConnectivityChecker, available, true, true); + checkAvailable(globalQuorumConnectivityChecker, available, true); } @Test public void execute_Noop() { - checkAvailable(noopChecker, new HashSet<>(), true, false); + checkAvailable(noopChecker, new HashSet<>(), true); } @Test public void execute_ZeroWaitHasConnections() throws InterruptedException { Sink sink = new Sink(true, true, new HashSet<>()); - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); Assert.assertFalse(zeroWaitChecker.execute(peers, this::getDatacenter)); - boolean hasConnections = false; - for (int i = 0; i < TIMEOUT_NANOS; i+= 10) - { - hasConnections = checkAllConnectionTypesSeen(sink); - if (hasConnections) - break; - Thread.sleep(0, 10); - } - MessagingService.instance().clearMessageSinks(); - Assert.assertTrue(hasConnections); + 
MessagingService.instance().outboundSink.clear(); } private void checkAvailable(StartupClusterConnectivityChecker checker, Set available, - boolean shouldPass, boolean checkConnections) + boolean shouldPass) { Sink sink = new Sink(true, true, available); - MessagingService.instance().addMessageSink(sink); + MessagingService.instance().outboundSink.add(sink); Assert.assertEquals(shouldPass, checker.execute(peers, this::getDatacenter)); - if (checkConnections) - Assert.assertTrue(checkAllConnectionTypesSeen(sink)); - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().outboundSink.clear(); } private void copyCount(Set source, Set dest, int count) @@ -223,25 +208,7 @@ private void copyCount(Set source, Set d } } - private boolean checkAllConnectionTypesSeen(Sink sink) - { - boolean result = true; - for (InetAddressAndPort peer : peers) - { - if (peer.equals(FBUtilities.getBroadcastAddressAndPort())) - continue; - ConnectionTypeRecorder recorder = sink.seenConnectionRequests.get(peer); - result = recorder != null; - if (!result) - break; - - result = recorder.seenSmallMessageRequest; - result &= recorder.seenLargeMessageRequest; - } - return result; - } - - private static class Sink implements IMessageSink + private static class Sink implements BiPredicate, InetAddressAndPort> { private final boolean markAliveInGossip; private final boolean processConnectAck; @@ -257,39 +224,25 @@ private static class Sink implements IMessageSink } @Override - public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to) + public boolean test(Message message, InetAddressAndPort to) { ConnectionTypeRecorder recorder = seenConnectionRequests.computeIfAbsent(to, inetAddress -> new ConnectionTypeRecorder()); - if (message.connectionType == SMALL_MESSAGE) - { - Assert.assertFalse(recorder.seenSmallMessageRequest); - recorder.seenSmallMessageRequest = true; - } - else - { - Assert.assertFalse(recorder.seenLargeMessageRequest); - 
recorder.seenLargeMessageRequest = true; - } if (!aliveHosts.contains(to)) return false; if (processConnectAck) { - MessageIn msgIn = MessageIn.create(to, message.payload, Collections.emptyMap(), MessagingService.Verb.REQUEST_RESPONSE, 1); - MessagingService.instance().getRegisteredCallback(id).callback.response(msgIn); + Message msgIn = Message.builder(Verb.REQUEST_RSP, message.payload) + .from(to) + .build(); + MessagingService.instance().callbacks.get(message.id(), to).callback.onResponse(msgIn); } if (markAliveInGossip) Gossiper.runInGossipStageBlocking(() -> Gossiper.instance.realMarkAlive(to, new EndpointState(new HeartBeatState(1, 1)))); return false; } - - @Override - public boolean allowIncomingMessage(MessageIn message, int id) - { - return false; - } } private static class ConnectionTypeRecorder diff --git a/test/unit/org/apache/cassandra/net/TestAbstractAsyncPromise.java b/test/unit/org/apache/cassandra/net/TestAbstractAsyncPromise.java new file mode 100644 index 000000000000..fd61b093cbfa --- /dev/null +++ b/test/unit/org/apache/cassandra/net/TestAbstractAsyncPromise.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Assert; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; + +abstract class TestAbstractAsyncPromise extends TestAbstractPromise +{ + void testOneSuccess(Promise promise, boolean setUncancellable, boolean tryOrSet, V value, V otherValue) + { + List results = new ArrayList<>(); + List order = new ArrayList<>(); + class ListenerFactory + { + int count = 0; + + public GenericFutureListener> get() + { + int id = count++; + return p -> { results.add(p.getNow()); order.add(id); }; + } + public GenericFutureListener> getRecursive() + { + int id = count++; + return p -> { promise.addListener(get()); results.add(p.getNow()); order.add(id); }; + } + } + ListenerFactory listeners = new ListenerFactory(); + Async async = new Async(); + promise.addListener(listeners.get()); + promise.addListeners(listeners.getRecursive(), listeners.get()); + promise.addListener(listeners.getRecursive()); + success(promise, Promise::getNow, null); + success(promise, Promise::isSuccess, false); + success(promise, Promise::isDone, false); + success(promise, Promise::isCancelled, false); + success(promise, Promise::isCancellable, true); + if (setUncancellable) + { + success(promise, Promise::setUncancellable, true); + success(promise, Promise::setUncancellable, true); + success(promise, p -> p.cancel(true), false); + success(promise, p -> p.cancel(false), false); + } + success(promise, Promise::isCancellable, !setUncancellable); + async.success(promise, Promise::get, value); + async.success(promise, p -> p.get(1L, TimeUnit.SECONDS), value); + 
async.success(promise, Promise::await, promise); + async.success(promise, Promise::awaitUninterruptibly, promise); + async.success(promise, p -> p.await(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.await(1000L), true); + async.success(promise, p -> p.awaitUninterruptibly(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.awaitUninterruptibly(1000L), true); + async.success(promise, Promise::sync, promise); + async.success(promise, Promise::syncUninterruptibly, promise); + if (tryOrSet) promise.trySuccess(value); + else promise.setSuccess(value); + success(promise, p -> p.cancel(true), false); + success(promise, p -> p.cancel(false), false); + failure(promise, p -> p.setSuccess(null), IllegalStateException.class); + failure(promise, p -> p.setFailure(new NullPointerException()), IllegalStateException.class); + success(promise, Promise::getNow, value); + success(promise, p -> p.trySuccess(otherValue), false); + success(promise, p -> p.tryFailure(new NullPointerException()), false); + success(promise, Promise::getNow, value); + success(promise, Promise::cause, null); + promise.addListener(listeners.get()); + promise.addListeners(listeners.getRecursive(), listeners.get()); + promise.addListener(listeners.getRecursive()); + success(promise, Promise::isSuccess, true); + success(promise, Promise::isDone, true); + success(promise, Promise::isCancelled, false); + success(promise, Promise::isCancellable, false); + async.verify(); + Assert.assertEquals(listeners.count, results.size()); + Assert.assertEquals(listeners.count, order.size()); + for (V result : results) + Assert.assertEquals(value, result); + for (int i = 0 ; i < order.size() ; ++i) + Assert.assertEquals(i, order.get(i).intValue()); + } + + void testOneFailure(Promise promise, boolean setUncancellable, boolean tryOrSet, Throwable cause, V otherValue) + { + List results = new ArrayList<>(); + List order = new ArrayList<>(); + Async async = new Async(); + class ListenerFactory + { + int 
count = 0; + + public GenericFutureListener> get() + { + int id = count++; + return p -> { results.add(p.cause()); order.add(id); }; + } + public GenericFutureListener> getRecursive() + { + int id = count++; + return p -> { promise.addListener(get()); results.add(p.cause()); order.add(id); }; + } + } + ListenerFactory listeners = new ListenerFactory(); + promise.addListener(listeners.get()); + promise.addListeners(listeners.getRecursive(), listeners.get()); + promise.addListener(listeners.getRecursive()); + success(promise, Promise::isSuccess, false); + success(promise, Promise::isDone, false); + success(promise, Promise::isCancelled, false); + success(promise, Promise::isCancellable, true); + if (setUncancellable) + { + success(promise, Promise::setUncancellable, true); + success(promise, Promise::setUncancellable, true); + success(promise, p -> p.cancel(true), false); + success(promise, p -> p.cancel(false), false); + } + success(promise, Promise::isCancellable, !setUncancellable); + success(promise, Promise::getNow, null); + success(promise, Promise::cause, null); + async.failure(promise, Promise::get, ExecutionException.class); + async.failure(promise, p -> p.get(1L, TimeUnit.SECONDS), ExecutionException.class); + async.success(promise, Promise::await, promise); + async.success(promise, Promise::awaitUninterruptibly, promise); + async.success(promise, p -> p.await(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.await(1000L), true); + async.success(promise, p -> p.awaitUninterruptibly(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.awaitUninterruptibly(1000L), true); + async.failure(promise, Promise::sync, cause); + async.failure(promise, Promise::syncUninterruptibly, cause); + if (tryOrSet) promise.tryFailure(cause); + else promise.setFailure(cause); + success(promise, p -> p.cancel(true), false); + success(promise, p -> p.cancel(false), false); + failure(promise, p -> p.setSuccess(null), IllegalStateException.class); + 
failure(promise, p -> p.setFailure(new NullPointerException()), IllegalStateException.class); + success(promise, Promise::cause, cause); + success(promise, Promise::getNow, null); + success(promise, p -> p.trySuccess(otherValue), false); + success(promise, p -> p.tryFailure(new NullPointerException()), false); + success(promise, Promise::getNow, null); + success(promise, Promise::cause, cause); + promise.addListener(listeners.get()); + promise.addListeners(listeners.getRecursive(), listeners.get()); + promise.addListener(listeners.getRecursive()); + success(promise, Promise::isSuccess, false); + success(promise, Promise::isDone, true); + success(promise, Promise::isCancelled, false); + success(promise, Promise::isCancellable, false); + async.verify(); + Assert.assertEquals(listeners.count, results.size()); + Assert.assertEquals(listeners.count, order.size()); + for (Throwable result : results) + Assert.assertEquals(cause, result); + for (int i = 0 ; i < order.size() ; ++i) + Assert.assertEquals(i, order.get(i).intValue()); + } + + public void testOneCancellation(Promise promise, boolean interruptIfRunning, V otherValue) + { + Async async = new Async(); + success(promise, Promise::isCancellable, true); + success(promise, Promise::getNow, null); + success(promise, Promise::cause, null); + async.failure(promise, Promise::get, CancellationException.class); + async.failure(promise, p -> p.get(1L, TimeUnit.SECONDS), CancellationException.class); + async.success(promise, Promise::await, promise); + async.success(promise, Promise::awaitUninterruptibly, promise); + async.success(promise, p -> p.await(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.await(1000L), true); + async.success(promise, p -> p.awaitUninterruptibly(1L, TimeUnit.SECONDS), true); + async.success(promise, p -> p.awaitUninterruptibly(1000L), true); + async.failure(promise, Promise::sync, CancellationException.class); + async.failure(promise, Promise::syncUninterruptibly, 
CancellationException.class); + promise.cancel(interruptIfRunning); + failure(promise, p -> p.setFailure(null), IllegalStateException.class); + failure(promise, p -> p.setFailure(null), IllegalStateException.class); + Assert.assertTrue(promise.cause() instanceof CancellationException); + success(promise, Promise::getNow, null); + success(promise, p -> p.trySuccess(otherValue), false); + success(promise, Promise::getNow, null); + Assert.assertTrue(promise.cause() instanceof CancellationException); + success(promise, Promise::isSuccess, false); + success(promise, Promise::isDone, true); + success(promise, Promise::isCancelled, true); + success(promise, Promise::isCancellable, false); + async.verify(); + } + + + public void testOneTimeout(Promise promise, boolean setUncancellable) + { + Async async = new Async(); + if (setUncancellable) + success(promise, Promise::setUncancellable, true); + success(promise, Promise::isCancellable, !setUncancellable); + async.failure(promise, p -> p.get(1L, TimeUnit.MILLISECONDS), TimeoutException.class); + async.success(promise, p -> p.await(1L, TimeUnit.MILLISECONDS), false); + async.success(promise, p -> p.await(1L), false); + async.success(promise, p -> p.awaitUninterruptibly(1L, TimeUnit.MILLISECONDS), false); + async.success(promise, p -> p.awaitUninterruptibly(1L), false); + Uninterruptibles.sleepUninterruptibly(10L, TimeUnit.MILLISECONDS); + async.verify(); + } + +} diff --git a/test/unit/org/apache/cassandra/net/TestAbstractPromise.java b/test/unit/org/apache/cassandra/net/TestAbstractPromise.java new file mode 100644 index 000000000000..963c61fe35ee --- /dev/null +++ b/test/unit/org/apache/cassandra/net/TestAbstractPromise.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; + +import org.junit.Assert; + +import io.netty.util.concurrent.Promise; +import net.openhft.chronicle.core.util.ThrowingBiConsumer; +import net.openhft.chronicle.core.util.ThrowingConsumer; +import net.openhft.chronicle.core.util.ThrowingFunction; + +abstract class TestAbstractPromise +{ + final ExecutorService exec = Executors.newCachedThreadPool(); + + class Async + { + final List> waitingOn = new ArrayList<>(); + void verify() + { + for (int i = 0 ; i < waitingOn.size() ; ++i) + { + try + { + waitingOn.get(i).accept(100L, TimeUnit.MILLISECONDS); + } + catch (Throwable t) + { + throw new AssertionError("" + i, t); + } + } + } + void failure(Promise promise, ThrowingConsumer, ?> action, Throwable failsWith) + { + waitingOn.add(exec.submit(() -> TestAbstractPromise.failure(promise, action, failsWith))::get); + } + void failure(Promise promise, ThrowingConsumer, ?> action, Class failsWith) + { + waitingOn.add(exec.submit(() -> TestAbstractPromise.failure(promise, action, failsWith))::get); + } + void failure(Promise promise, ThrowingConsumer, ?> action, Predicate failsWith) + { + waitingOn.add(exec.submit(() 
-> TestAbstractPromise.failure(promise, action, failsWith))::get); + } +

, R> void success(P promise, ThrowingFunction action, R result) + { + waitingOn.add(exec.submit(() -> TestAbstractPromise.success(promise, action, result))::get); + } + } + + private static void failure(Promise promise, ThrowingConsumer, ?> action, Throwable failsWith) + { + failure(promise, action, t -> Objects.equals(failsWith, t)); + } + + static void failure(Promise promise, ThrowingConsumer, ?> action, Class failsWith) + { + failure(promise, action, failsWith::isInstance); + } + + private static void failure(Promise promise, ThrowingConsumer, ?> action, Predicate failsWith) + { + Throwable fail = null; + try + { + action.accept(promise); + } + catch (Throwable t) + { + fail = t; + } + if (!failsWith.test(fail)) + throw new AssertionError(fail); + } + + static

, R> void success(P promise, ThrowingFunction action, R result) + { + try + { + Assert.assertEquals(result, action.apply(promise)); + } + catch (Throwable t) + { + throw new AssertionError(t); + } + } + +} diff --git a/test/unit/org/apache/cassandra/net/TestChannel.java b/test/unit/org/apache/cassandra/net/TestChannel.java new file mode 100644 index 000000000000..feddab0c93e0 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/TestChannel.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.FileRegion; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.cassandra.net.FrameEncoder; +import org.apache.cassandra.net.GlobalBufferPoolAllocator; + +public class TestChannel extends EmbeddedChannel +{ + final int inFlightLimit; + int inFlight; + + ChannelOutboundBuffer flush; + long flushBytes; + + public TestChannel(int inFlightLimit) + { + this.inFlightLimit = inFlightLimit; + } + + // we override ByteBuf to prevent retain() from working, to avoid release() since it is not needed in our usage + // since the lifetime must live longer, we simply copy any outbound ByteBuf here for our tests + protected void doWrite(ChannelOutboundBuffer in) + { + assert flush == null || flush == in; + doWrite(in, in.totalPendingWriteBytes()); + } + + private void doWrite(ChannelOutboundBuffer flush, long flushBytes) + { + while (true) { + Object msg = flush.current(); + if (msg == null) { + this.flush = null; + this.flushBytes = 0; + return; + } + + if (inFlight >= inFlightLimit) + { + this.flush = flush; + this.flushBytes = flushBytes; + return; + } + + ByteBuf buf; + if (msg instanceof FileRegion) + { + buf = GlobalBufferPoolAllocator.instance.directBuffer((int) ((FileRegion) msg).count()); + try + { + ((FileRegion) msg).transferTo(new WritableByteChannel() + { + public int write(ByteBuffer src) + { + buf.setBytes(0, src); + return buf.writerIndex(); + } + + public boolean isOpen() { return false; } + + public void close() { } + }, 0); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + } + else if (msg instanceof ByteBuf) + { + buf = ((ByteBuf)msg).copy(); + } + else if (msg instanceof FrameEncoder.Payload) + { + buf = 
Unpooled.wrappedBuffer(((FrameEncoder.Payload)msg).buffer).copy(); + } + else + { + System.err.println("Unexpected message type " + msg); + throw new IllegalArgumentException(); + } + + inFlight += buf.readableBytes(); + handleOutboundMessage(buf); + flush.remove(); + } + } + + public T readOutbound() + { + T msg = super.readOutbound(); + if (msg instanceof ByteBuf) + { + inFlight -= ((ByteBuf) msg).readableBytes(); + if (flush != null && inFlight < inFlightLimit) + doWrite(flush, flushBytes); + } + return msg; + } +} + diff --git a/test/unit/org/apache/cassandra/net/async/TestScheduledFuture.java b/test/unit/org/apache/cassandra/net/TestScheduledFuture.java similarity index 97% rename from test/unit/org/apache/cassandra/net/async/TestScheduledFuture.java rename to test/unit/org/apache/cassandra/net/TestScheduledFuture.java index f5475ce2b0c3..456f8c4a852b 100644 --- a/test/unit/org/apache/cassandra/net/async/TestScheduledFuture.java +++ b/test/unit/org/apache/cassandra/net/TestScheduledFuture.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.cassandra.net.async; +package org.apache.cassandra.net; import java.util.concurrent.Delayed; import java.util.concurrent.ExecutionException; diff --git a/test/unit/org/apache/cassandra/net/WriteCallbackInfoTest.java b/test/unit/org/apache/cassandra/net/WriteCallbackInfoTest.java index e226d32e75b0..b4bf8b7d62ef 100644 --- a/test/unit/org/apache/cassandra/net/WriteCallbackInfoTest.java +++ b/test/unit/org/apache/cassandra/net/WriteCallbackInfoTest.java @@ -30,8 +30,6 @@ import org.apache.cassandra.db.RegularAndStaticColumns; import org.apache.cassandra.db.partitions.PartitionUpdate; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessagingService.Verb; import org.apache.cassandra.schema.MockSchema; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.service.paxos.Commit; @@ -50,8 +48,8 @@ public static void initDD() @Test public void testShouldHint() throws Exception { - testShouldHint(Verb.COUNTER_MUTATION, ConsistencyLevel.ALL, true, false); - for (Verb verb : new Verb[] { Verb.PAXOS_COMMIT, Verb.MUTATION }) + testShouldHint(Verb.COUNTER_MUTATION_REQ, ConsistencyLevel.ALL, true, false); + for (Verb verb : new Verb[] { Verb.PAXOS_COMMIT_REQ, Verb.MUTATION_REQ }) { testShouldHint(verb, ConsistencyLevel.ALL, true, true); testShouldHint(verb, ConsistencyLevel.ANY, true, false); @@ -62,11 +60,11 @@ public void testShouldHint() throws Exception private void testShouldHint(Verb verb, ConsistencyLevel cl, boolean allowHints, boolean expectHint) throws Exception { TableMetadata metadata = MockSchema.newTableMetadata("", ""); - Object payload = verb == Verb.PAXOS_COMMIT + Object payload = verb == Verb.PAXOS_COMMIT_REQ ? 
new Commit(UUID.randomUUID(), new PartitionUpdate.Builder(metadata, ByteBufferUtil.EMPTY_BYTE_BUFFER, RegularAndStaticColumns.NONE, 1).build()) : new Mutation(PartitionUpdate.simpleBuilder(metadata, "").build()); - WriteCallbackInfo wcbi = new WriteCallbackInfo(full(InetAddressAndPort.getByName("192.168.1.1")), null, new MessageOut(verb, payload, null), null, cl, allowHints); + RequestCallbacks.WriteCallbackInfo wcbi = new RequestCallbacks.WriteCallbackInfo(Message.out(verb, payload), full(InetAddressAndPort.getByName("192.168.1.1")), null, cl, allowHints); Assert.assertEquals(expectHint, wcbi.shouldHint()); if (expectHint) { diff --git a/test/unit/org/apache/cassandra/net/async/ByteBufDataOutputPlusTest.java b/test/unit/org/apache/cassandra/net/async/ByteBufDataOutputPlusTest.java deleted file mode 100644 index 959c37aa2528..000000000000 --- a/test/unit/org/apache/cassandra/net/async/ByteBufDataOutputPlusTest.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.junit.After; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.buffer.Unpooled; -import org.apache.cassandra.SchemaLoader; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.db.ColumnFamilyStore; -import org.apache.cassandra.db.Keyspace; -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.db.RowUpdateBuilder; -import org.apache.cassandra.db.compaction.CompactionManager; -import org.apache.cassandra.db.marshal.AsciiType; -import org.apache.cassandra.db.marshal.BytesType; -import org.apache.cassandra.io.util.DataOutputBuffer; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.io.util.Memory; -import org.apache.cassandra.io.util.SafeMemory; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.schema.KeyspaceParams; - -public class ByteBufDataOutputPlusTest -{ - private static final String KEYSPACE1 = "NettyPipilineTest"; - private static final String STANDARD1 = "Standard1"; - private static final int columnCount = 128; - - private ByteBuf buf; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - SchemaLoader.prepareServer(); - SchemaLoader.createKeyspace(KEYSPACE1, - KeyspaceParams.simple(1), - SchemaLoader.standardCFMD(KEYSPACE1, STANDARD1, columnCount, AsciiType.instance, BytesType.instance)); - CompactionManager.instance.disableAutoCompaction(); - } - - @After - public void tearDown() - { - if (buf != null) - buf.release(); - } - - @Test - public void compareBufferSizes() throws IOException - { - final int currentFrameSize = getMessage().message.serializedSize(MessagingService.current_version); - - ByteBuffer buffer = ByteBuffer.allocateDirect(currentFrameSize); 
//bufferedOut.nioBuffer(0, bufferedOut.writableBytes()); - getMessage().message.serialize(new DataOutputBuffer(buffer), MessagingService.current_version); - Assert.assertFalse(buffer.hasRemaining()); - Assert.assertEquals(buffer.capacity(), buffer.position()); - - ByteBuf bbosOut = PooledByteBufAllocator.DEFAULT.ioBuffer(currentFrameSize, currentFrameSize); - try - { - getMessage().message.serialize(new ByteBufDataOutputPlus(bbosOut), MessagingService.current_version); - - Assert.assertFalse(bbosOut.isWritable()); - Assert.assertEquals(bbosOut.capacity(), bbosOut.writerIndex()); - - Assert.assertEquals(buffer.position(), bbosOut.writerIndex()); - for (int i = 0; i < currentFrameSize; i++) - { - Assert.assertEquals(buffer.get(i), bbosOut.getByte(i)); - } - } - finally - { - bbosOut.release(); - } - } - - private QueuedMessage getMessage() - { - ColumnFamilyStore cfs1 = Keyspace.open(KEYSPACE1).getColumnFamilyStore(STANDARD1); - ByteBuffer buf = ByteBuffer.allocate(1 << 10); - RowUpdateBuilder rowUpdateBuilder = new RowUpdateBuilder(cfs1.metadata.get(), 0, "k") - .clustering("bytes"); - for (int i = 0; i < columnCount; i++) - rowUpdateBuilder.add("val" + i, buf); - - Mutation mutation = rowUpdateBuilder.build(); - return new QueuedMessage(mutation.createMessage(), 42); - } - - @Test - public void compareDOS() throws IOException - { - buf = PooledByteBufAllocator.DEFAULT.ioBuffer(1024, 1024); - ByteBuffer buffer = ByteBuffer.allocateDirect(1024); - - ByteBufDataOutputPlus byteBufDataOutputPlus = new ByteBufDataOutputPlus(buf); - DataOutputBuffer dataOutputBuffer = new DataOutputBuffer(buffer); - - write(byteBufDataOutputPlus); - write(dataOutputBuffer); - - Assert.assertEquals(buffer.position(), buf.writerIndex()); - for (int i = 0; i < buffer.position(); i++) - { - Assert.assertEquals(buffer.get(i), buf.getByte(i)); - } - } - - private void write(DataOutputPlus out) throws IOException - { - ByteBuffer b = ByteBuffer.allocate(8); - b.putLong(29811134237462734L); - 
out.write(b); - b = ByteBuffer.allocateDirect(8); - b.putDouble(92367.4253647890626); - out.write(b); - - out.writeInt(29319236); - - byte[] array = new byte[17]; - for (int i = 0; i < array.length; i++) - array[i] = (byte)i; - out.write(array, 0 , array.length); - - out.write(42); - out.writeUTF("This is a great string!!"); - out.writeByte(-100); - out.writeUnsignedVInt(3247634L); - out.writeVInt(12313695L); - out.writeBoolean(true); - out.writeShort(4371); - out.writeChar('j'); - out.writeLong(472348263487234L); - out.writeFloat(34534.12623F); - out.writeDouble(0.2384253D); - out.writeBytes("Write my bytes"); - out.writeChars("These are some swell chars"); - - Memory memory = new SafeMemory(8); - memory.setLong(0, -21365123651231L); - out.write(memory, 0, memory.size()); - memory.close(); - } - - @Test (expected = UnsupportedOperationException.class) - public void applyToChannel() throws IOException - { - ByteBufDataOutputPlus out = new ByteBufDataOutputPlus(Unpooled.wrappedBuffer(new byte[0])); - out.applyToChannel(null); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java b/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java deleted file mode 100644 index 02115124d4de..000000000000 --- a/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.util.Optional; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -import com.google.common.net.InetAddresses; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOption; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.channel.WriteBufferWaterMark; -import io.netty.channel.embedded.EmbeddedChannel; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.ChannelWriter.CoalescingChannelWriter; -import org.apache.cassandra.utils.CoalescingStrategies; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; - -import static org.apache.cassandra.net.MessagingService.Verb.ECHO; - -/** - * with the write_Coalescing_* methods, if there's data in the channel.unsafe().outboundBuffer() - * it means that there's something in the channel that hasn't yet been flushed to the transport (socket). - * once a flush occurs, there will be an entry in EmbeddedChannel's outboundQueue. those two facts are leveraged in these tests. 
- */ -public class ChannelWriterTest -{ - private static final int COALESCE_WINDOW_MS = 10; - - private EmbeddedChannel channel; - private ChannelWriter channelWriter; - private NonSendingOutboundMessagingConnection omc; - private Optional coalescingStrategy; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setup() - { - OutboundConnectionIdentifier id = OutboundConnectionIdentifier.small(InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 0), - InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 0)); - channel = new EmbeddedChannel(); - omc = new NonSendingOutboundMessagingConnection(id, null, Optional.empty()); - channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - channel.pipeline().addFirst(new MessageOutHandler(id, MessagingService.current_version, channelWriter, () -> null)); - coalescingStrategy = CoalescingStrategies.newCoalescingStrategy(CoalescingStrategies.Strategy.FIXED.name(), COALESCE_WINDOW_MS, null, "test"); - } - - @Test - public void create_nonCoalescing() - { - Assert.assertSame(ChannelWriter.SimpleChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()).getClass()); - } - - @Test - public void create_Coalescing() - { - Assert.assertSame(CoalescingChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, coalescingStrategy).getClass()); - } - - @Test - public void write_IsWritable() - { - Assert.assertTrue(channel.isWritable()); - Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); - Assert.assertTrue(channel.isWritable()); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void write_NotWritable() - { - channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); - - // send one message through, which will 
trigger the writability check (and turn it off) - Assert.assertTrue(channel.isWritable()); - ByteBuf buf = channel.alloc().buffer(8, 8); - channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); - Assert.assertFalse(channel.isWritable()); - Assert.assertFalse(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); - Assert.assertFalse(channel.isWritable()); - Assert.assertFalse(channel.releaseOutbound()); - buf.release(); - } - - @Test - public void write_NotWritableButWriteAnyway() - { - channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); - - // send one message through, which will trigger the writability check (and turn it off) - Assert.assertTrue(channel.isWritable()); - ByteBuf buf = channel.alloc().buffer(8, 8); - channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); - Assert.assertFalse(channel.isWritable()); - Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), false)); - Assert.assertTrue(channel.isWritable()); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void write_Coalescing_LostRaceForFlushTask() - { - CoalescingChannelWriter channelWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages()); - channelWriter.scheduledFlush.set(true); - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); - Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0); - Assert.assertFalse(channel.releaseOutbound()); - Assert.assertTrue(channelWriter.scheduledFlush.get()); - } - - @Test - public void write_Coalescing_HitMinMessageCountForImmediateCoalesce() - { - CoalescingChannelWriter channelWriter = resetEnvForCoalescing(1); - - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() 
== 0); - Assert.assertFalse(channelWriter.scheduledFlush.get()); - Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); - - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); - Assert.assertTrue(channel.releaseOutbound()); - Assert.assertFalse(channelWriter.scheduledFlush.get()); - } - - @Test - public void write_Coalescing_ScheduleFlushTask() - { - CoalescingChannelWriter channelWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages()); - - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); - Assert.assertFalse(channelWriter.scheduledFlush.get()); - Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); - - Assert.assertTrue(channelWriter.scheduledFlush.get()); - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0); - Assert.assertTrue(channelWriter.scheduledFlush.get()); - - // this unfortunately know a little too much about how the sausage is made in CoalescingChannelWriter :-/ - channel.runScheduledPendingTasks(); - channel.runPendingTasks(); - Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); - Assert.assertFalse(channelWriter.scheduledFlush.get()); - Assert.assertTrue(channel.releaseOutbound()); - } - - private CoalescingChannelWriter resetEnvForCoalescing(int minMessagesForCoalesce) - { - channel = new EmbeddedChannel(); - CoalescingChannelWriter cw = new CoalescingChannelWriter(channel, omc::handleMessageResult, coalescingStrategy.get(), minMessagesForCoalesce); - channel.pipeline().addFirst(new ChannelOutboundHandlerAdapter() - { - public void flush(ChannelHandlerContext ctx) throws Exception - { - cw.onTriggeredFlush(ctx); - } - }); - omc.setChannelWriter(cw); - return cw; - } - - @Test - public void writeBacklog_Empty() - { - BlockingQueue queue = new LinkedBlockingQueue<>(); - Assert.assertEquals(0, 
channelWriter.writeBacklog(queue, false)); - Assert.assertFalse(channel.releaseOutbound()); - } - - @Test - public void writeBacklog_ChannelNotWritable() - { - Assert.assertTrue(channel.isWritable()); - // force the channel to be non writable - channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); - ByteBuf buf = channel.alloc().buffer(8, 8); - channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); - Assert.assertFalse(channel.isWritable()); - - Assert.assertEquals(0, channelWriter.writeBacklog(new LinkedBlockingQueue<>(), false)); - Assert.assertFalse(channel.releaseOutbound()); - Assert.assertFalse(channel.isWritable()); - buf.release(); - } - - @Test - public void writeBacklog_NotEmpty() - { - BlockingQueue queue = new LinkedBlockingQueue<>(); - int count = 12; - for (int i = 0; i < count; i++) - queue.offer(new QueuedMessage(new MessageOut<>(ECHO), i)); - Assert.assertEquals(count, channelWriter.writeBacklog(queue, false)); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void close() - { - Assert.assertFalse(channelWriter.isClosed()); - Assert.assertTrue(channel.isOpen()); - channelWriter.close(); - Assert.assertFalse(channel.isOpen()); - Assert.assertTrue(channelWriter.isClosed()); - } - - @Test - public void softClose() - { - Assert.assertFalse(channelWriter.isClosed()); - Assert.assertTrue(channel.isOpen()); - channelWriter.softClose(); - Assert.assertFalse(channel.isOpen()); - Assert.assertTrue(channelWriter.isClosed()); - } - - @Test - public void handleMessagePromise_FutureIsCancelled() - { - ChannelPromise promise = channel.newPromise(); - promise.cancel(false); - channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); - Assert.assertTrue(channel.isActive()); - Assert.assertEquals(1, omc.getCompletedMessages().longValue()); - Assert.assertEquals(0, omc.getDroppedMessages().longValue()); - } - - @Test - public void 
handleMessagePromise_ExpiredException_DoNotRetryMsg() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new ExpiredException()); - - channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); - Assert.assertTrue(channel.isActive()); - Assert.assertEquals(1, omc.getCompletedMessages().longValue()); - Assert.assertEquals(1, omc.getDroppedMessages().longValue()); - Assert.assertFalse(omc.sendMessageInvoked); - } - - @Test - public void handleMessagePromise_NonIOException() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new NullPointerException("this is a test")); - channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); - Assert.assertTrue(channel.isActive()); - Assert.assertEquals(1, omc.getCompletedMessages().longValue()); - Assert.assertEquals(0, omc.getDroppedMessages().longValue()); - Assert.assertFalse(omc.sendMessageInvoked); - } - - @Test - public void handleMessagePromise_IOException_ChannelNotClosed_RetryMsg() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new IOException("this is a test")); - Assert.assertTrue(channel.isActive()); - channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true); - - Assert.assertFalse(channel.isActive()); - Assert.assertEquals(1, omc.getCompletedMessages().longValue()); - Assert.assertEquals(0, omc.getDroppedMessages().longValue()); - Assert.assertTrue(omc.sendMessageInvoked); - } - - @Test - public void handleMessagePromise_Cancelled() - { - ChannelPromise promise = channel.newPromise(); - promise.cancel(false); - Assert.assertTrue(channel.isActive()); - channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true); - - Assert.assertTrue(channel.isActive()); - Assert.assertEquals(1, omc.getCompletedMessages().longValue()); - Assert.assertEquals(0, 
omc.getDroppedMessages().longValue()); - Assert.assertFalse(omc.sendMessageInvoked); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java deleted file mode 100644 index 087f49ed2201..000000000000 --- a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.Optional; - -import com.google.common.net.InetAddresses; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameters; - -import io.netty.channel.embedded.EmbeddedChannel; -import org.apache.cassandra.SchemaLoader; -import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; -import org.apache.cassandra.db.ColumnFamilyStore; -import org.apache.cassandra.db.Keyspace; -import org.apache.cassandra.db.Mutation; -import org.apache.cassandra.db.RowUpdateBuilder; -import org.apache.cassandra.db.compaction.CompactionManager; -import org.apache.cassandra.db.marshal.AsciiType; -import org.apache.cassandra.db.marshal.BytesType; -import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.schema.KeyspaceParams; - -import static org.apache.cassandra.net.async.InboundHandshakeHandler.State.HANDSHAKE_COMPLETE; -import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.READY; - -@RunWith(Parameterized.class) -public class HandshakeHandlersTest -{ - private static final String KEYSPACE1 = "NettyPipilineTest"; - private static final String STANDARD1 = "Standard1"; - - private static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9999); - private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9999); - private static final OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(LOCAL_ADDR, REMOTE_ADDR); 
- private final int messagingVersion; - - @BeforeClass - public static void beforeClass() throws ConfigurationException - { - SchemaLoader.prepareServer(); - SchemaLoader.createKeyspace(KEYSPACE1, - KeyspaceParams.simple(1), - SchemaLoader.standardCFMD(KEYSPACE1, STANDARD1, 0, AsciiType.instance, BytesType.instance)); - CompactionManager.instance.disableAutoCompaction(); - } - - public HandshakeHandlersTest(int messagingVersion) - { - this.messagingVersion = messagingVersion; - } - - @Parameters() - public static Iterable generateData() - { - return Arrays.asList(MessagingService.VERSION_30, MessagingService.VERSION_40); - } - - @Test - public void handshake_HappyPath() - { - // beacuse both CHH & SHH are ChannelInboundHandlers, we can't use the same EmbeddedChannel to handle them - InboundHandshakeHandler inboundHandshakeHandler = new InboundHandshakeHandler(new TestAuthenticator(true)); - EmbeddedChannel inboundChannel = new EmbeddedChannel(inboundHandshakeHandler); - - OutboundMessagingConnection imc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator()); - OutboundConnectionParams params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(imc::finishHandshake) - .mode(NettyFactory.Mode.MESSAGING) - .protocolVersion(MessagingService.current_version) - .coalescingStrategy(Optional.empty()) - .build(); - OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params); - EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler); - Assert.assertEquals(1, outboundChannel.outboundMessages().size()); - - // move internode protocol Msg1 to the server's channel - Object o; - while ((o = outboundChannel.readOutbound()) != null) - inboundChannel.writeInbound(o); - Assert.assertEquals(1, inboundChannel.outboundMessages().size()); - - // move internode protocol Msg2 to the client's channel - while ((o = inboundChannel.readOutbound()) != null) - 
outboundChannel.writeInbound(o); - Assert.assertEquals(1, outboundChannel.outboundMessages().size()); - - // move internode protocol Msg3 to the server's channel - while ((o = outboundChannel.readOutbound()) != null) - inboundChannel.writeInbound(o); - - Assert.assertEquals(READY, imc.getState()); - Assert.assertEquals(HANDSHAKE_COMPLETE, inboundHandshakeHandler.getState()); - } - - @Test - public void lotsOfMutations_NoCompression() throws IOException - { - lotsOfMutations(false); - } - - @Test - public void lotsOfMutations_WithCompression() throws IOException - { - lotsOfMutations(true); - } - - private void lotsOfMutations(boolean compress) - { - TestChannels channels = buildChannels(compress); - EmbeddedChannel outboundChannel = channels.outboundChannel; - EmbeddedChannel inboundChannel = channels.inboundChannel; - - // now the actual test! - ByteBuffer buf = ByteBuffer.allocate(1 << 10); - byte[] bytes = "ThisIsA16CharStr".getBytes(); - while (buf.remaining() > 0) - buf.put(bytes); - - // write a bunch of messages to the channel - ColumnFamilyStore cfs1 = Keyspace.open(KEYSPACE1).getColumnFamilyStore(STANDARD1); - int count = 1024; - for (int i = 0; i < count; i++) - { - if (i % 2 == 0) - { - Mutation mutation = new RowUpdateBuilder(cfs1.metadata.get(), 0, "k") - .clustering("bytes") - .add("val", buf) - .build(); - - QueuedMessage msg = new QueuedMessage(mutation.createMessage(), i); - outboundChannel.writeAndFlush(msg); - } - else - { - outboundChannel.writeAndFlush(new QueuedMessage(new MessageOut<>(MessagingService.Verb.ECHO), i)); - } - } - outboundChannel.flush(); - - // move the messages to the other channel - Object o; - while ((o = outboundChannel.readOutbound()) != null) - inboundChannel.writeInbound(o); - - Assert.assertTrue(outboundChannel.outboundMessages().isEmpty()); - Assert.assertFalse(inboundChannel.finishAndReleaseAll()); - } - - private TestChannels buildChannels(boolean compress) - { - OutboundConnectionParams params = 
OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(this::nop) - .mode(NettyFactory.Mode.MESSAGING) - .compress(compress) - .coalescingStrategy(Optional.empty()) - .protocolVersion(MessagingService.current_version) - .backlogSupplier(this::nopBacklog) - .build(); - OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params); - EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler); - OutboundMessagingConnection omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator()); - omc.setTargetVersion(messagingVersion); - outboundHandshakeHandler.setupPipeline(outboundChannel, messagingVersion); - - // remove the outbound handshake message from the outbound messages - outboundChannel.outboundMessages().clear(); - - InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(true)); - EmbeddedChannel inboundChannel = new EmbeddedChannel(handler); - handler.setupMessagingPipeline(inboundChannel.pipeline(), REMOTE_ADDR, compress, messagingVersion); - - return new TestChannels(outboundChannel, inboundChannel); - } - - private static class TestChannels - { - final EmbeddedChannel outboundChannel; - final EmbeddedChannel inboundChannel; - - TestChannels(EmbeddedChannel outboundChannel, EmbeddedChannel inboundChannel) - { - this.outboundChannel = outboundChannel; - this.inboundChannel = inboundChannel; - } - } - - private Void nop(OutboundHandshakeHandler.HandshakeResult handshakeResult) - { - // do nothing, really - return null; - } - - private QueuedMessage nopBacklog() - { - return null; - } -} diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java deleted file mode 100644 index af486368c607..000000000000 --- a/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to 
the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.cassandra.net.async; - -import org.junit.After; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.PooledByteBufAllocator; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; -import org.apache.cassandra.utils.FBUtilities; - -import static org.junit.Assert.assertEquals; - -public class HandshakeProtocolTest -{ - private ByteBuf buf; - - @BeforeClass - public static void before() - { - // Kind of stupid, but the test trigger the initialization of the MessagingService class and that require - // DatabaseDescriptor to be configured ... 
- DatabaseDescriptor.daemonInitialization(); - } - - @After - public void tearDown() - { - if (buf != null && buf.refCnt() > 0) - buf.release(); - } - - @Test - public void firstMessageTest() throws Exception - { - firstMessageTest(NettyFactory.Mode.MESSAGING, false); - firstMessageTest(NettyFactory.Mode.MESSAGING, true); - firstMessageTest(NettyFactory.Mode.STREAMING, false); - firstMessageTest(NettyFactory.Mode.STREAMING, true); - } - - private void firstMessageTest(NettyFactory.Mode mode, boolean compression) throws Exception - { - FirstHandshakeMessage before = new FirstHandshakeMessage(MessagingService.current_version, mode, compression); - buf = before.encode(PooledByteBufAllocator.DEFAULT); - FirstHandshakeMessage after = FirstHandshakeMessage.maybeDecode(buf); - assertEquals(before, after); - assertEquals(before.hashCode(), after.hashCode()); - Assert.assertFalse(before.equals(null)); - } - - @Test - public void secondMessageTest() throws Exception - { - SecondHandshakeMessage before = new SecondHandshakeMessage(MessagingService.current_version); - buf = before.encode(PooledByteBufAllocator.DEFAULT); - SecondHandshakeMessage after = SecondHandshakeMessage.maybeDecode(buf); - assertEquals(before, after); - assertEquals(before.hashCode(), after.hashCode()); - Assert.assertFalse(before.equals(null)); - } - - @Test - public void thirdMessageTest() throws Exception - { - ThirdHandshakeMessage before = new ThirdHandshakeMessage(MessagingService.current_version, FBUtilities.getBroadcastAddressAndPort()); - buf = before.encode(PooledByteBufAllocator.DEFAULT); - ThirdHandshakeMessage after = ThirdHandshakeMessage.maybeDecode(buf); - assertEquals(before, after); - assertEquals(before.hashCode(), after.hashCode()); - Assert.assertFalse(before.equals(null)); - } -} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java deleted file 
mode 100644 index 93a1c204554d..000000000000 --- a/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.ArrayList; - -import com.google.common.net.InetAddresses; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufOutputStream; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelPromise; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.compression.Lz4FrameDecoder; -import io.netty.handler.codec.compression.Lz4FrameEncoder; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; -import org.apache.cassandra.net.MessagingService; -import 
org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; -import org.apache.cassandra.net.async.InboundHandshakeHandler.State; - -import static org.apache.cassandra.net.async.NettyFactory.Mode.MESSAGING; - -public class InboundHandshakeHandlerTest -{ - private static final InetAddressAndPort addr = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 0); - private static final int MESSAGING_VERSION = MessagingService.current_version; - private static final int VERSION_30 = MessagingService.VERSION_30; - - private InboundHandshakeHandler handler; - private EmbeddedChannel channel; - private ByteBuf buf; - - @BeforeClass - public static void beforeClass() - { - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setUp() - { - TestAuthenticator authenticator = new TestAuthenticator(false); - handler = new InboundHandshakeHandler(authenticator); - channel = new EmbeddedChannel(handler); - } - - @After - public void tearDown() - { - if (buf != null) - buf.release(); - channel.finishAndReleaseAll(); - } - - @Test - public void handleAuthenticate_Good() - { - handler = new InboundHandshakeHandler(new TestAuthenticator(true)); - channel = new EmbeddedChannel(handler); - boolean result = handler.handleAuthenticate(new InetSocketAddress(addr.address, addr.port), channel.pipeline().firstContext()); - Assert.assertTrue(result); - Assert.assertTrue(channel.isOpen()); - } - - @Test - public void handleAuthenticate_Bad() - { - boolean result = handler.handleAuthenticate(new InetSocketAddress(addr.address, addr.port), channel.pipeline().firstContext()); - Assert.assertFalse(result); - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.isActive()); - } - - @Test - public void handleAuthenticate_BadSocketAddr() - { - boolean result = handler.handleAuthenticate(new FakeSocketAddress(), channel.pipeline().firstContext()); 
- Assert.assertFalse(result); - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.isActive()); - } - - private static class FakeSocketAddress extends SocketAddress - { } - - @Test - public void decode_AlreadyFailed() - { - handler.setState(State.HANDSHAKE_FAIL); - buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); - handler.decode(channel.pipeline().firstContext(), buf, new ArrayList<>()); - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.isActive()); - Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState()); - } - - @Test - public void handleStart_NotEnoughInputBytes() throws IOException - { - ByteBuf buf = Unpooled.EMPTY_BUFFER; - State state = handler.handleStart(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.START, state); - Assert.assertTrue(channel.isOpen()); - Assert.assertTrue(channel.isActive()); - } - - @Test (expected = IOException.class) - public void handleStart_BadMagic() throws IOException - { - InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(false)); - EmbeddedChannel channel = new EmbeddedChannel(handler); - buf = Unpooled.buffer(32, 32); - - FirstHandshakeMessage first = new FirstHandshakeMessage(MESSAGING_VERSION, - MESSAGING, - true); - - buf.writeInt(MessagingService.PROTOCOL_MAGIC << 2); - buf.writeInt(first.encodeFlags()); - handler.handleStart(channel.pipeline().firstContext(), buf); - } - - @Test - public void handleStart_VersionTooHigh() throws IOException - { - channel.eventLoop(); - buf = new FirstHandshakeMessage(MESSAGING_VERSION + 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); - State state = handler.handleStart(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.HANDSHAKE_FAIL, state); - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.isActive()); - } - - @Test - public void handleStart_VersionLessThan3_0() throws IOException - { - 
buf = new FirstHandshakeMessage(VERSION_30 - 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); - State state = handler.handleStart(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.HANDSHAKE_FAIL, state); - - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.isActive()); - } - - @Test - public void handleStart_HappyPath_Messaging() throws IOException - { - buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); - State state = handler.handleStart(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, state); - if (buf.refCnt() > 0) - buf.release(); - - buf = new ThirdHandshakeMessage(MESSAGING_VERSION, addr).encode(PooledByteBufAllocator.DEFAULT); - state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); - - Assert.assertEquals(State.HANDSHAKE_COMPLETE, state); - Assert.assertTrue(channel.isOpen()); - Assert.assertTrue(channel.isActive()); - Assert.assertFalse(channel.outboundMessages().isEmpty()); - channel.releaseOutbound(); - } - - @Test - public void handleMessagingStartResponse_NotEnoughInputBytes() throws IOException - { - ByteBuf buf = Unpooled.EMPTY_BUFFER; - State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, state); - Assert.assertTrue(channel.isOpen()); - Assert.assertTrue(channel.isActive()); - } - - @Test - public void handleMessagingStartResponse_BadMaxVersion() throws IOException - { - buf = Unpooled.buffer(32, 32); - buf.writeInt(MESSAGING_VERSION + 1); - CompactEndpointSerializationHelper.instance.serialize(addr, new ByteBufDataOutputPlus(buf), MESSAGING_VERSION + 1); - State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.HANDSHAKE_FAIL, state); - Assert.assertFalse(channel.isOpen()); - 
Assert.assertFalse(channel.isActive()); - } - - @Test - public void handleMessagingStartResponse_HappyPath() throws IOException - { - buf = Unpooled.buffer(32, 32); - buf.writeInt(MESSAGING_VERSION); - CompactEndpointSerializationHelper.instance.serialize(addr, new ByteBufDataOutputPlus(buf), MESSAGING_VERSION); - State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); - Assert.assertEquals(State.HANDSHAKE_COMPLETE, state); - Assert.assertTrue(channel.isOpen()); - Assert.assertTrue(channel.isActive()); - } - - @Test - public void setupPipeline_NoCompression() - { - ChannelPipeline pipeline = channel.pipeline(); - Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class)); - - handler.setupMessagingPipeline(pipeline, addr, false, MESSAGING_VERSION); - Assert.assertNotNull(pipeline.get(MessageInHandler.class)); - Assert.assertNull(pipeline.get(Lz4FrameDecoder.class)); - Assert.assertNull(pipeline.get(Lz4FrameEncoder.class)); - Assert.assertNull(pipeline.get(InboundHandshakeHandler.class)); - } - - @Test - public void setupPipeline_WithCompression() - { - ChannelPipeline pipeline = channel.pipeline(); - Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class)); - - handler.setupMessagingPipeline(pipeline, addr, true, MESSAGING_VERSION); - Assert.assertNotNull(pipeline.get(MessageInHandler.class)); - Assert.assertNotNull(pipeline.get(Lz4FrameDecoder.class)); - Assert.assertNull(pipeline.get(Lz4FrameEncoder.class)); - Assert.assertNull(pipeline.get(InboundHandshakeHandler.class)); - } - - @Test - public void failHandshake() - { - ChannelPromise future = channel.newPromise(); - handler.setHandshakeTimeout(future); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(channel.isOpen()); - handler.failHandshake(channel.pipeline().firstContext()); - Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState()); - Assert.assertTrue(future.isCancelled()); - Assert.assertFalse(channel.isOpen()); - } - - @Test - public 
void failHandshake_AlreadyConnected() - { - ChannelPromise future = channel.newPromise(); - handler.setHandshakeTimeout(future); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(channel.isOpen()); - handler.setState(State.HANDSHAKE_COMPLETE); - handler.failHandshake(channel.pipeline().firstContext()); - Assert.assertSame(State.HANDSHAKE_COMPLETE, handler.getState()); - Assert.assertTrue(channel.isOpen()); - } - - @Test - public void failHandshake_TaskIsCancelled() - { - ChannelPromise future = channel.newPromise(); - future.cancel(false); - handler.setHandshakeTimeout(future); - handler.setState(State.AWAIT_MESSAGING_START_RESPONSE); - Assert.assertTrue(channel.isOpen()); - handler.failHandshake(channel.pipeline().firstContext()); - Assert.assertSame(State.AWAIT_MESSAGING_START_RESPONSE, handler.getState()); - Assert.assertTrue(channel.isOpen()); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java deleted file mode 100644 index 5997861ef4e7..000000000000 --- a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.EOFException; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.UUID; -import java.util.function.BiConsumer; - -import com.google.common.collect.ImmutableList; -import com.google.common.net.InetAddresses; -import com.google.common.primitives.Shorts; -import org.junit.After; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameters; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.embedded.EmbeddedChannel; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.exceptions.RequestFailureReason; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; -import org.apache.cassandra.utils.UUIDGen; - -import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; - -@RunWith(Parameterized.class) -public class MessageInHandlerTest -{ - private static final int MSG_ID = 42; - private static InetAddressAndPort addr; - - private final int messagingVersion; - - private ByteBuf buf; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - addr = InetAddressAndPort.getByAddress(InetAddresses.forString("127.0.73.101")); - } - - public MessageInHandlerTest(int messagingVersion) - { - 
this.messagingVersion = messagingVersion; - } - - @Parameters() - public static Iterable generateData() - { - return Arrays.asList(MessagingService.VERSION_30, MessagingService.VERSION_40); - } - - @After - public void tearDown() - { - if (buf != null && buf.refCnt() > 0) - buf.release(); - } - - private BaseMessageInHandler getHandler(InetAddressAndPort addr, int messagingVersion, BiConsumer messageConsumer) - { - return messagingVersion >= MessagingService.VERSION_40 ? - new MessageInHandler(addr, messagingVersion, messageConsumer) : - new MessageInHandlerPre40(addr, messagingVersion, messageConsumer); - } - - @Test(expected = AssertionError.class) - public void testBadVersionForHandler() - { - if (messagingVersion < MessagingService.VERSION_40) - new MessageInHandler(addr, messagingVersion, null); - else - new MessageInHandlerPre40(addr, messagingVersion, null); - } - - @Test - public void decode_BadMagic() - { - int len = MessageInHandler.FIRST_SECTION_BYTE_COUNT; - buf = Unpooled.buffer(len, len); - buf.writeInt(-1); - buf.writerIndex(len); - - BaseMessageInHandler handler = getHandler(addr, messagingVersion, null); - EmbeddedChannel channel = new EmbeddedChannel(handler); - Assert.assertTrue(channel.isOpen()); - channel.writeInbound(buf); - Assert.assertFalse(channel.isOpen()); - } - - @Test - public void decode_HappyPath_NoParameters() throws Exception - { - MessageInWrapper result = decode_HappyPath(Collections.emptyMap()); - Assert.assertTrue(result.messageIn.parameters.isEmpty()); - } - - @Test - public void decode_HappyPath_WithParameters() throws Exception - { - UUID uuid = UUIDGen.getTimeUUID(); - Map parameters = new EnumMap<>(ParameterType.class); - parameters.put(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); - parameters.put(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); - parameters.put(ParameterType.TRACE_SESSION, uuid); - MessageInWrapper result = decode_HappyPath(parameters); 
- Assert.assertEquals(3, result.messageIn.parameters.size()); - Assert.assertTrue(result.messageIn.isFailureResponse()); - Assert.assertEquals(RequestFailureReason.READ_TOO_MANY_TOMBSTONES, result.messageIn.getFailureReason()); - Assert.assertEquals(uuid, result.messageIn.parameters.get(ParameterType.TRACE_SESSION)); - } - - private MessageInWrapper decode_HappyPath(Map parameters) throws Exception - { - MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); - for (Map.Entry param : parameters.entrySet()) - msgOut = msgOut.withParameter(param.getKey(), param.getValue()); - serialize(msgOut, MSG_ID); - - MessageInWrapper wrapper = new MessageInWrapper(); - BaseMessageInHandler handler = getHandler(addr, messagingVersion, wrapper.messageConsumer); - List out = new ArrayList<>(); - handler.decode(null, buf, out); - - Assert.assertNotNull(wrapper.messageIn); - Assert.assertEquals(MSG_ID, wrapper.id); - Assert.assertEquals(msgOut.from, wrapper.messageIn.from); - Assert.assertEquals(msgOut.verb, wrapper.messageIn.verb); - Assert.assertTrue(out.isEmpty()); - - return wrapper; - } - - private void serialize(MessageOut msgOut, int id) throws IOException - { - if (buf == null) - buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! 
- buf.writeInt(MessagingService.PROTOCOL_MAGIC); - buf.writeInt(id); // this is the id - buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime())); - - msgOut.serialize(new ByteBufDataOutputPlus(buf), messagingVersion); - } - - @Test - public void decode_WithHalfReceivedParameters() throws Exception - { - MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); - UUID uuid = UUIDGen.getTimeUUID(); - msgOut = msgOut.withParameter(ParameterType.TRACE_SESSION, uuid); - - serialize(msgOut, MSG_ID); - - // move the write index pointer back a few bytes to simulate like the full bytes are not present. - // yeah, it's lame, but it tests the basics of what is happening during the deserialiization - int originalWriterIndex = buf.writerIndex(); - buf.writerIndex(originalWriterIndex - 6); - - MessageInWrapper wrapper = new MessageInWrapper(); - BaseMessageInHandler handler = getHandler(addr, messagingVersion, wrapper.messageConsumer); - List out = new ArrayList<>(); - handler.decode(null, buf, out); - - Assert.assertNull(wrapper.messageIn); - - BaseMessageInHandler.MessageHeader header = handler.getMessageHeader(); - Assert.assertEquals(MSG_ID, header.messageId); - Assert.assertEquals(msgOut.verb, header.verb); - Assert.assertEquals(msgOut.from, header.from); - Assert.assertTrue(out.isEmpty()); - - // now, set the writer index back to the original value to pretend that we actually got more bytes in - buf.writerIndex(originalWriterIndex); - handler.decode(null, buf, out); - Assert.assertNotNull(wrapper.messageIn); - Assert.assertTrue(out.isEmpty()); - } - - @Test - public void canReadNextParam_HappyPath() throws IOException - { - buildParamBufPre40(13); - Assert.assertTrue(MessageInHandlerPre40.canReadNextParam(buf)); - } - - @Test - public void canReadNextParam_OnlyFirstByte() throws IOException - { - buildParamBufPre40(13); - buf.writerIndex(1); - 
Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); - } - - @Test - public void canReadNextParam_PartialUTF() throws IOException - { - buildParamBufPre40(13); - buf.writerIndex(5); - Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); - } - - @Test - public void canReadNextParam_TruncatedValueLength() throws IOException - { - buildParamBufPre40(13); - buf.writerIndex(buf.writerIndex() - 13 - 2); - Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); - } - - @Test - public void canReadNextParam_MissingLastBytes() throws IOException - { - buildParamBufPre40(13); - buf.writerIndex(buf.writerIndex() - 2); - Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); - } - - private void buildParamBufPre40(int valueLength) throws IOException - { - buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! - - try (ByteBufDataOutputPlus output = new ByteBufDataOutputPlus(buf)) - { - output.writeUTF("name"); - byte[] array = new byte[valueLength]; - output.writeInt(array.length); - output.write(array); - } - } - - @Test - public void exceptionHandled() - { - BaseMessageInHandler handler = getHandler(addr, messagingVersion, null); - EmbeddedChannel channel = new EmbeddedChannel(handler); - Assert.assertTrue(channel.isOpen()); - handler.exceptionCaught(channel.pipeline().firstContext(), new EOFException()); - Assert.assertFalse(channel.isOpen()); - } - - /** - * this is for handling the bug uncovered by CASSANDRA-14574. - * - * TL;DR if we run into a problem processing a message out an incoming buffer (and we close the channel, etc), - * do not attempt to process anymore messages from the buffer (force the channel closed and - * reject any more read attempts from the buffer). - * - * The idea here is to put several messages into a ByteBuf, pass that to the channel/handler, and make sure that - * only the initial, correct messages in the buffer are processed. 
After one messages fails the rest of the buffer - * should be ignored. - */ - @Test - public void exceptionHandled_14574() throws IOException - { - Map parameters = new EnumMap<>(ParameterType.class); - parameters.put(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); - parameters.put(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); - MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); - for (Map.Entry param : parameters.entrySet()) - msgOut = msgOut.withParameter(param.getKey(), param.getValue()); - - // put one complete, correct message into the buffer - serialize(msgOut, 1); - - // add a second message, but intentionally corrupt it by manipulating a byte in it's range - int startPosition = buf.writerIndex(); - serialize(msgOut, 2); - int positionToHack = startPosition + 2; - buf.setByte(positionToHack, buf.getByte(positionToHack) - 1); - - // add one more complete, correct message into the buffer - serialize(msgOut, 3); - - MessageIdsWrapper wrapper = new MessageIdsWrapper(); - BaseMessageInHandler handler = getHandler(addr, messagingVersion, wrapper.messageConsumer); - EmbeddedChannel channel = new EmbeddedChannel(handler); - Assert.assertTrue(channel.isOpen()); - channel.writeOneInbound(buf); - - Assert.assertFalse(buf.isReadable()); - Assert.assertEquals(BaseMessageInHandler.State.CLOSED, handler.getState()); - Assert.assertFalse(channel.isOpen()); - Assert.assertEquals(1, wrapper.ids.size()); - Assert.assertEquals(Integer.valueOf(1), wrapper.ids.get(0)); - } - - private static class MessageInWrapper - { - MessageIn messageIn; - int id; - - final BiConsumer messageConsumer = (messageIn, integer) -> - { - this.messageIn = messageIn; - this.id = integer; - }; - } - - private static class MessageIdsWrapper - { - private final ArrayList ids = new ArrayList<>(); - - final BiConsumer messageConsumer = (messageIn, integer) -> 
ids.add(integer); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java deleted file mode 100644 index 9aa26251cfd0..000000000000 --- a/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.channel.DefaultChannelPromise; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.UnsupportedMessageTypeException; -import io.netty.handler.timeout.IdleStateEvent; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.tracing.Tracing; - -public class MessageOutHandlerTest -{ - private static final int MESSAGING_VERSION = MessagingService.current_version; - - private ChannelWriter channelWriter; - private EmbeddedChannel channel; - private MessageOutHandler handler; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - DatabaseDescriptor.createAllDirectories(); - } - - @Before - public void setup() throws Exception - { - setup(MessageOutHandler.AUTO_FLUSH_THRESHOLD); - } - - private void setup(int flushThreshold) throws Exception - { - OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(InetAddressAndPort.getByNameOverrideDefaults("127.0.0.1", 0), - InetAddressAndPort.getByNameOverrideDefaults("127.0.0.2", 0)); - 
OutboundMessagingConnection omc = new NonSendingOutboundMessagingConnection(connectionId, null, Optional.empty()); - channel = new EmbeddedChannel(); - channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - handler = new MessageOutHandler(connectionId, MESSAGING_VERSION, channelWriter, () -> null, flushThreshold); - channel.pipeline().addLast(handler); - } - - @Test - public void write_NoFlush() throws ExecutionException, InterruptedException, TimeoutException - { - MessageOut message = new MessageOut(MessagingService.Verb.ECHO); - ChannelFuture future = channel.write(new QueuedMessage(message, 42)); - Assert.assertTrue(!future.isDone()); - Assert.assertFalse(channel.releaseOutbound()); - } - - @Test - public void write_WithFlush() throws Exception - { - setup(1); - MessageOut message = new MessageOut(MessagingService.Verb.ECHO); - ChannelFuture future = channel.write(new QueuedMessage(message, 42)); - Assert.assertTrue(future.isSuccess()); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void serializeMessage() throws IOException - { - channelWriter.pendingMessageCount.set(1); - QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1); - ChannelFuture future = channel.writeAndFlush(msg); - - Assert.assertTrue(future.isSuccess()); - Assert.assertTrue(1 <= channel.outboundMessages().size()); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void wrongMessageType() - { - ChannelPromise promise = new DefaultChannelPromise(channel); - Assert.assertFalse(handler.isMessageValid("this is the wrong message type", promise)); - - Assert.assertFalse(promise.isSuccess()); - Assert.assertNotNull(promise.cause()); - Assert.assertSame(UnsupportedMessageTypeException.class, promise.cause().getClass()); - } - - @Test - public void unexpiredMessage() - { - QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1); - 
ChannelPromise promise = new DefaultChannelPromise(channel); - Assert.assertTrue(handler.isMessageValid(msg, promise)); - - // we won't know if it was successful yet, but we'll know if it's a failure because cause will be set - Assert.assertNull(promise.cause()); - } - - @Test - public void expiredMessage() - { - QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1, 0, true, true); - ChannelPromise promise = new DefaultChannelPromise(channel); - Assert.assertFalse(handler.isMessageValid(msg, promise)); - - Assert.assertFalse(promise.isSuccess()); - Assert.assertNotNull(promise.cause()); - Assert.assertSame(ExpiredException.class, promise.cause().getClass()); - Assert.assertTrue(channel.outboundMessages().isEmpty()); - } - - @Test - public void write_MessageTooLarge() - { - write_BadMessageSize(Integer.MAX_VALUE + 1); - } - - @Test - public void write_MessageSizeIsBananas() - { - write_BadMessageSize(Integer.MIN_VALUE + 10000); - } - - private void write_BadMessageSize(long size) - { - IVersionedSerializer serializer = new IVersionedSerializer() - { - public void serialize(Object o, DataOutputPlus out, int version) - { } - - public Object deserialize(DataInputPlus in, int version) - { - return null; - } - - public long serializedSize(Object o, int version) - { - return size; - } - }; - MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer); - ChannelFuture future = channel.write(new QueuedMessage(message, 42)); - Throwable t = future.cause(); - Assert.assertNotNull(t); - Assert.assertSame(IllegalStateException.class, t.getClass()); - Assert.assertTrue(channel.isOpen()); - Assert.assertFalse(channel.releaseOutbound()); - } - - @Test - public void writeForceExceptionPath() - { - IVersionedSerializer serializer = new IVersionedSerializer() - { - public void serialize(Object o, DataOutputPlus out, int version) - { - throw new RuntimeException("this exception is part of the test - DON'T 
PANIC"); - } - - public Object deserialize(DataInputPlus in, int version) - { - return null; - } - - public long serializedSize(Object o, int version) - { - return 42; - } - }; - MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer); - ChannelFuture future = channel.write(new QueuedMessage(message, 42)); - Throwable t = future.cause(); - Assert.assertNotNull(t); - Assert.assertFalse(channel.isOpen()); - Assert.assertFalse(channel.releaseOutbound()); - } - - @Test - public void captureTracingInfo_ForceException() - { - MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) - .withParameter(ParameterType.TRACE_SESSION, new byte[9]); - handler.captureTracingInfo(new QueuedMessage(message, 42)); - } - - @Test - public void captureTracingInfo_UnknownSession() - { - UUID uuid = UUID.randomUUID(); - MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) - .withParameter(ParameterType.TRACE_SESSION, uuid); - handler.captureTracingInfo(new QueuedMessage(message, 42)); - } - - @Test - public void captureTracingInfo_KnownSession() - { - Tracing.instance.newSession(new HashMap<>()); - MessageOut message = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE); - handler.captureTracingInfo(new QueuedMessage(message, 42)); - } - - @Test - public void userEventTriggered_RandomObject() - { - Assert.assertTrue(channel.isOpen()); - ChannelUserEventSender sender = new ChannelUserEventSender(); - channel.pipeline().addFirst(sender); - sender.sendEvent("ThisIsAFakeEvent"); - Assert.assertTrue(channel.isOpen()); - } - - @Test - public void userEventTriggered_Idle_NoPendingBytes() - { - Assert.assertTrue(channel.isOpen()); - ChannelUserEventSender sender = new ChannelUserEventSender(); - channel.pipeline().addFirst(sender); - sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT); - Assert.assertTrue(channel.isOpen()); - } - - @Test - public void userEventTriggered_Idle_WithPendingBytes() - { - 
Assert.assertTrue(channel.isOpen()); - ChannelUserEventSender sender = new ChannelUserEventSender(); - channel.pipeline().addFirst(sender); - - MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE); - channel.writeOutbound(new QueuedMessage(message, 42)); - sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT); - Assert.assertFalse(channel.isOpen()); - } - - private static class ChannelUserEventSender extends ChannelOutboundHandlerAdapter - { - private ChannelHandlerContext ctx; - - @Override - public void handlerAdded(final ChannelHandlerContext ctx) throws Exception - { - this.ctx = ctx; - } - - private void sendEvent(Object event) - { - ctx.fireUserEventTriggered(event); - } - } -} diff --git a/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java b/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java deleted file mode 100644 index 18d17e88664b..000000000000 --- a/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.net.InetSocketAddress; -import java.util.Optional; - -import com.google.common.net.InetAddresses; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.epoll.EpollServerSocketChannel; -import io.netty.channel.group.ChannelGroup; -import io.netty.channel.group.DefaultChannelGroup; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.ssl.SslHandler; -import io.netty.util.concurrent.GlobalEventExecutor; -import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.NettyFactory.InboundInitializer; -import org.apache.cassandra.net.async.NettyFactory.OutboundInitializer; -import org.apache.cassandra.service.NativeTransportService; -import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.NativeLibrary; - -public class NettyFactoryTest -{ - private static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9876); - private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9876); - 
private static final int receiveBufferSize = 1 << 16; - private static final IInternodeAuthenticator AUTHENTICATOR = new AllowAllInternodeAuthenticator(); - private static final boolean EPOLL_AVAILABLE = NativeTransportService.useEpoll(); - - private ChannelGroup channelGroup; - private NettyFactory factory; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setUp() - { - channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); - } - - @After - public void tearDown() throws Exception - { - if (factory != null) - factory.close(); - } - - @Test - public void createServerChannel_Epoll() - { - if (!EPOLL_AVAILABLE) - return; - Channel inboundChannel = createServerChannel(true); - if (inboundChannel == null) - return; - Assert.assertEquals(EpollServerSocketChannel.class, inboundChannel.getClass()); - inboundChannel.close(); - } - - private Channel createServerChannel(boolean useEpoll) - { - InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); - factory = new NettyFactory(useEpoll); - - try - { - return factory.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize); - } - catch (Exception e) - { - if (NativeLibrary.osType == NativeLibrary.OSType.LINUX) - throw e; - - return null; - } - } - - @Test - public void createServerChannel_Nio() - { - Channel inboundChannel = createServerChannel(false); - Assert.assertNotNull("we should always be able to get a NIO channel", inboundChannel); - Assert.assertEquals(NioServerSocketChannel.class, inboundChannel.getClass()); - inboundChannel.close(); - } - - @Test(expected = ConfigurationException.class) - public void createServerChannel_SecondAttemptToBind() - { - Channel inboundChannel = null; - try - { - InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); - inboundChannel = NettyFactory.instance.createInboundChannel(LOCAL_ADDR, 
inboundInitializer, receiveBufferSize); - NettyFactory.instance.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize); - } - finally - { - if (inboundChannel != null) - inboundChannel.close(); - } - } - - @Test(expected = ConfigurationException.class) - public void createServerChannel_UnbindableAddress() - { - InetAddressAndPort addr = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("1.1.1.1"), 9876); - InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); - NettyFactory.instance.createInboundChannel(addr, inboundInitializer, receiveBufferSize); - } - - @Test - public void deterineAcceptGroupSize() - { - ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions(); - serverEncryptionOptions.enabled = false; - Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - serverEncryptionOptions.enabled = true; - Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - - serverEncryptionOptions.enable_legacy_ssl_storage_port = true; - Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - serverEncryptionOptions.enable_legacy_ssl_storage_port = false; - - InetAddressAndPort originalBroadcastAddr = FBUtilities.getBroadcastAddressAndPort(); - try - { - FBUtilities.setBroadcastInetAddress(InetAddresses.increment(FBUtilities.getLocalAddressAndPort().address)); - DatabaseDescriptor.setListenOnBroadcastAddress(true); - - serverEncryptionOptions.enabled = false; - Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - serverEncryptionOptions.enabled = true; - Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - - serverEncryptionOptions.enable_legacy_ssl_storage_port = true; - Assert.assertEquals(4, NettyFactory.determineAcceptGroupSize(serverEncryptionOptions)); - } - finally - { - 
FBUtilities.setBroadcastInetAddress(originalBroadcastAddr.address); - DatabaseDescriptor.setListenOnBroadcastAddress(false); - } - } - - @Test - public void getEventLoopGroup_EpollWithIoRatioBoost() - { - if (!EPOLL_AVAILABLE) - return; - getEventLoopGroup_Epoll(true); - } - - private EpollEventLoopGroup getEventLoopGroup_Epoll(boolean ioBoost) - { - EventLoopGroup eventLoopGroup; - try - { - eventLoopGroup = NettyFactory.getEventLoopGroup(true, 1, "testEventLoopGroup", ioBoost); - } - catch (Exception e) - { - if (NativeLibrary.osType == NativeLibrary.OSType.LINUX) - throw e; - - // ignore as epoll is only available on linux platforms, so don't fail the test on other OSes - return null; - } - - Assert.assertTrue(eventLoopGroup instanceof EpollEventLoopGroup); - return (EpollEventLoopGroup) eventLoopGroup; - } - - @Test - public void getEventLoopGroup_EpollWithoutIoRatioBoost() - { - if (!EPOLL_AVAILABLE) - return; - getEventLoopGroup_Epoll(false); - } - - @Test - public void getEventLoopGroup_NioWithoutIoRatioBoost() - { - getEventLoopGroup_Nio(true); - } - - private NioEventLoopGroup getEventLoopGroup_Nio(boolean ioBoost) - { - EventLoopGroup eventLoopGroup = NettyFactory.getEventLoopGroup(false, 1, "testEventLoopGroup", ioBoost); - Assert.assertTrue(eventLoopGroup instanceof NioEventLoopGroup); - return (NioEventLoopGroup) eventLoopGroup; - } - - @Test - public void getEventLoopGroup_NioWithIoRatioBoost() - { - getEventLoopGroup_Nio(true); - } - - @Test - public void createOutboundBootstrap_Epoll() - { - if (!EPOLL_AVAILABLE) - return; - Bootstrap bootstrap = createOutboundBootstrap(true); - Assert.assertEquals(EpollEventLoopGroup.class, bootstrap.config().group().getClass()); - } - - private Bootstrap createOutboundBootstrap(boolean useEpoll) - { - factory = new NettyFactory(useEpoll); - OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR); - OutboundConnectionParams params = OutboundConnectionParams.builder() - 
.connectionId(id) - .coalescingStrategy(Optional.empty()) - .protocolVersion(MessagingService.current_version) - .build(); - return factory.createOutboundBootstrap(params); - } - - @Test - public void createOutboundBootstrap_Nio() - { - Bootstrap bootstrap = createOutboundBootstrap(false); - Assert.assertEquals(NioEventLoopGroup.class, bootstrap.config().group().getClass()); - } - - @Test - public void createInboundInitializer_WithoutSsl() throws Exception - { - ServerEncryptionOptions encryptionOptions = new ServerEncryptionOptions(); - encryptionOptions.enabled = false; - InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, encryptionOptions, channelGroup); - NioSocketChannel channel = new NioSocketChannel(); - initializer.initChannel(channel); - Assert.assertNull(channel.pipeline().get(SslHandler.class)); - Assert.assertNull(channel.pipeline().get(OptionalSslHandler.class)); - } - - private ServerEncryptionOptions encOptions() - { - ServerEncryptionOptions encryptionOptions; - encryptionOptions = new ServerEncryptionOptions(); - encryptionOptions.keystore = "test/conf/cassandra_ssl_test.keystore"; - encryptionOptions.keystore_password = "cassandra"; - encryptionOptions.truststore = "test/conf/cassandra_ssl_test.truststore"; - encryptionOptions.truststore_password = "cassandra"; - encryptionOptions.require_client_auth = false; - encryptionOptions.cipher_suites = new String[] {"TLS_RSA_WITH_AES_128_CBC_SHA"}; - return encryptionOptions; - } - - @Test - public void createInboundInitializer_WithSsl() throws Exception - { - ServerEncryptionOptions encryptionOptions = encOptions(); - encryptionOptions.enabled = true; - encryptionOptions.optional = false; - InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, encryptionOptions, channelGroup); - NioSocketChannel channel = new NioSocketChannel(); - Assert.assertNull(channel.pipeline().get(SslHandler.class)); - initializer.initChannel(channel); - 
Assert.assertNotNull(channel.pipeline().get(SslHandler.class)); - Assert.assertNull(channel.pipeline().get(OptionalSslHandler.class)); - } - - @Test - public void createInboundInitializer_WithOptionalSsl() throws Exception - { - ServerEncryptionOptions encryptionOptions = encOptions(); - encryptionOptions.enabled = true; - encryptionOptions.optional = true; - InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, encryptionOptions, channelGroup); - NioSocketChannel channel = new NioSocketChannel(); - Assert.assertNull(channel.pipeline().get(SslHandler.class)); - initializer.initChannel(channel); - Assert.assertNotNull(channel.pipeline().get(OptionalSslHandler.class)); - Assert.assertNull(channel.pipeline().get(SslHandler.class)); - } - - @Test - public void createOutboundInitializer_WithSsl() throws Exception - { - OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR); - OutboundConnectionParams params = OutboundConnectionParams.builder() - .connectionId(id) - .encryptionOptions(encOptions()) - .protocolVersion(MessagingService.current_version) - .build(); - OutboundInitializer outboundInitializer = new OutboundInitializer(params); - NioSocketChannel channel = new NioSocketChannel(); - Assert.assertNull(channel.pipeline().get(SslHandler.class)); - outboundInitializer.initChannel(channel); - Assert.assertNotNull(channel.pipeline().get(SslHandler.class)); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java b/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java deleted file mode 100644 index b0b15b81ab9b..000000000000 --- a/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.util.Optional; - -import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; -import org.apache.cassandra.config.EncryptionOptions; -import org.apache.cassandra.utils.CoalescingStrategies; - -class NonSendingOutboundMessagingConnection extends OutboundMessagingConnection -{ - boolean sendMessageInvoked; - - NonSendingOutboundMessagingConnection(OutboundConnectionIdentifier connectionId, EncryptionOptions.ServerEncryptionOptions encryptionOptions, Optional coalescingStrategy) - { - super(connectionId, encryptionOptions, coalescingStrategy, new AllowAllInternodeAuthenticator()); - } - - @Override - boolean sendMessage(QueuedMessage queuedMessage) - { - sendMessageInvoked = true; - return true; - } -} diff --git a/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java b/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java deleted file mode 100644 index 23a4a6873678..000000000000 --- a/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import org.junit.BeforeClass; -import org.junit.Test; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.net.MessagingService; - -public class OutboundConnectionParamsTest -{ - static int version; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - version = MessagingService.current_version; - } - - @Test (expected = IllegalArgumentException.class) - public void build_SendSizeLessThanZero() - { - OutboundConnectionParams.builder().protocolVersion(version).sendBufferSize(-1).build(); - } - - @Test (expected = IllegalArgumentException.class) - public void build_SendSizeHuge() - { - OutboundConnectionParams.builder().protocolVersion(version).sendBufferSize(1 << 30).build(); - } - - @Test (expected = IllegalArgumentException.class) - public void build_TcpConnectTimeoutLessThanZero() - { - OutboundConnectionParams.builder().protocolVersion(version).tcpConnectTimeoutInMS(-1).build(); - } - - @Test (expected = IllegalArgumentException.class) - public void build_TcpUserTimeoutLessThanZero() - { - OutboundConnectionParams.builder().protocolVersion(version).tcpUserTimeoutInMS(-1).build(); - } - - @Test - public void build_TcpUserTimeoutEqualsZero() - { - 
OutboundConnectionParams.builder().protocolVersion(version).tcpUserTimeoutInMS(0).build(); - } -} diff --git a/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java deleted file mode 100644 index 172667cb30f5..000000000000 --- a/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java +++ /dev/null @@ -1,258 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; - -import com.google.common.net.InetAddresses; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.compression.Lz4FrameDecoder; -import io.netty.handler.codec.compression.Lz4FrameEncoder; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; -import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; -import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult; - -import static org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult.UNKNOWN_PROTOCOL_VERSION; - -public class OutboundHandshakeHandlerTest -{ - private static final int MESSAGING_VERSION = MessagingService.current_version; - private static final InetAddressAndPort localAddr = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 0); - private static final InetAddressAndPort remoteAddr = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 0); - private static final String HANDLER_NAME = "clientHandshakeHandler"; - - private EmbeddedChannel channel; - private OutboundHandshakeHandler handler; - private OutboundConnectionIdentifier connectionId; - private OutboundConnectionParams params; - private CallbackHandler callbackHandler; - private ByteBuf buf; - - @BeforeClass - public static void before() - 
{ - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setup() - { - channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter()); - connectionId = OutboundConnectionIdentifier.small(localAddr, remoteAddr); - callbackHandler = new CallbackHandler(); - params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(handshakeResult -> callbackHandler.receive(handshakeResult)) - .mode(NettyFactory.Mode.MESSAGING) - .protocolVersion(MessagingService.current_version) - .coalescingStrategy(Optional.empty()) - .build(); - handler = new OutboundHandshakeHandler(params); - channel.pipeline().addFirst(HANDLER_NAME, handler); - } - - @After - public void tearDown() - { - if (buf != null && buf.refCnt() > 0) - buf.release(); - Assert.assertFalse(channel.finishAndReleaseAll()); - } - - @Test - public void decode_SmallInput() throws Exception - { - buf = Unpooled.buffer(2, 2); - List out = new LinkedList<>(); - handler.decode(channel.pipeline().firstContext(), buf, out); - Assert.assertEquals(0, buf.readerIndex()); - Assert.assertTrue(out.isEmpty()); - } - - @Test - public void decode_HappyPath() - { - buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT); - channel.writeInbound(buf); - Assert.assertEquals(1, channel.outboundMessages().size()); - Assert.assertTrue(channel.isOpen()); - - Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion); - Assert.assertEquals(HandshakeResult.Outcome.SUCCESS, callbackHandler.result.outcome); - Assert.assertFalse(channel.outboundMessages().isEmpty()); - - ByteBuf thridMsgBuf = (ByteBuf) channel.outboundMessages().poll(); - try - { - ThirdHandshakeMessage thirdHandshakeMessage = ThirdHandshakeMessage.maybeDecode(thridMsgBuf); - Assert.assertEquals(MESSAGING_VERSION, thirdHandshakeMessage.messagingVersion); - } - finally - { - thridMsgBuf.release(); - } - } - - @Test - public void decode_HappyPathThrowsException() - { - 
callbackHandler.failOnCallback = true; - buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT); - channel.writeInbound(buf); - Assert.assertFalse(channel.isOpen()); - Assert.assertEquals(1, channel.outboundMessages().size()); - Assert.assertTrue(channel.releaseOutbound()); // throw away any responses from decode() - - Assert.assertEquals(UNKNOWN_PROTOCOL_VERSION, callbackHandler.result.negotiatedMessagingVersion); - Assert.assertEquals(HandshakeResult.Outcome.NEGOTIATION_FAILURE, callbackHandler.result.outcome); - } - - @Test - public void decode_ReceivedUnexpectedLowerMsgVersion() - { - int msgVersion = MESSAGING_VERSION - 1; - buf = new SecondHandshakeMessage(msgVersion).encode(PooledByteBufAllocator.DEFAULT); - channel.writeInbound(buf); - Assert.assertTrue(channel.inboundMessages().isEmpty()); - - Assert.assertEquals(msgVersion, callbackHandler.result.negotiatedMessagingVersion); - Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome); - Assert.assertFalse(channel.isOpen()); - Assert.assertTrue(channel.outboundMessages().isEmpty()); - } - - @Test - public void decode_ReceivedExpectedLowerMsgVersion() - { - int msgVersion = MESSAGING_VERSION - 1; - channel.pipeline().remove(HANDLER_NAME); - params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(handshakeResult -> callbackHandler.receive(handshakeResult)) - .mode(NettyFactory.Mode.MESSAGING) - .protocolVersion(msgVersion) - .coalescingStrategy(Optional.empty()) - .build(); - handler = new OutboundHandshakeHandler(params); - channel.pipeline().addFirst(HANDLER_NAME, handler); - - buf = new SecondHandshakeMessage(msgVersion).encode(PooledByteBufAllocator.DEFAULT); - channel.writeInbound(buf); - Assert.assertTrue(channel.inboundMessages().isEmpty()); - - Assert.assertEquals(msgVersion, callbackHandler.result.negotiatedMessagingVersion); - Assert.assertEquals(HandshakeResult.Outcome.SUCCESS, 
callbackHandler.result.outcome); - Assert.assertTrue(channel.isOpen()); - Assert.assertFalse(channel.outboundMessages().isEmpty()); - - ByteBuf thridMsgBuf = (ByteBuf) channel.outboundMessages().poll(); - try - { - ThirdHandshakeMessage thirdHandshakeMessage = ThirdHandshakeMessage.maybeDecode(thridMsgBuf); - Assert.assertEquals(MESSAGING_VERSION, thirdHandshakeMessage.messagingVersion); - } - finally - { - thridMsgBuf.release(); - } - } - - @Test - public void decode_ReceivedHigherMsgVersion() - { - int msgVersion = MESSAGING_VERSION - 1; - channel.pipeline().remove(HANDLER_NAME); - params = OutboundConnectionParams.builder() - .connectionId(connectionId) - .callback(handshakeResult -> callbackHandler.receive(handshakeResult)) - .mode(NettyFactory.Mode.MESSAGING) - .protocolVersion(msgVersion) - .coalescingStrategy(Optional.empty()) - .build(); - handler = new OutboundHandshakeHandler(params); - channel.pipeline().addFirst(HANDLER_NAME, handler); - buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT); - channel.writeInbound(buf); - - Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion); - Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome); - } - - @Test - public void setupPipeline_WithCompression() - { - EmbeddedChannel chan = new EmbeddedChannel(new ChannelOutboundHandlerAdapter()); - ChannelPipeline pipeline = chan.pipeline(); - params = OutboundConnectionParams.builder(params).compress(true).protocolVersion(MessagingService.current_version).build(); - handler = new OutboundHandshakeHandler(params); - pipeline.addFirst(handler); - handler.setupPipeline(chan, MESSAGING_VERSION); - Assert.assertNotNull(pipeline.get(Lz4FrameEncoder.class)); - Assert.assertNull(pipeline.get(Lz4FrameDecoder.class)); - Assert.assertNotNull(pipeline.get(MessageOutHandler.class)); - } - - @Test - public void setupPipeline_NoCompression() - { - EmbeddedChannel chan = new 
EmbeddedChannel(new ChannelOutboundHandlerAdapter()); - ChannelPipeline pipeline = chan.pipeline(); - params = OutboundConnectionParams.builder(params).compress(false).protocolVersion(MessagingService.current_version).build(); - handler = new OutboundHandshakeHandler(params); - pipeline.addFirst(handler); - handler.setupPipeline(chan, MESSAGING_VERSION); - Assert.assertNull(pipeline.get(Lz4FrameEncoder.class)); - Assert.assertNull(pipeline.get(Lz4FrameDecoder.class)); - Assert.assertNotNull(pipeline.get(MessageOutHandler.class)); - } - - private static class CallbackHandler - { - boolean failOnCallback; - HandshakeResult result; - - Void receive(HandshakeResult handshakeResult) - { - if (failOnCallback) - { - // only fail the first callback - failOnCallback = false; - throw new RuntimeException("this exception is expected in the test - DON'T PANIC"); - } - result = handshakeResult; - return null; - } - } -} diff --git a/test/unit/org/apache/cassandra/net/async/OutboundMessagingConnectionTest.java b/test/unit/org/apache/cassandra/net/async/OutboundMessagingConnectionTest.java deleted file mode 100644 index 379031c6cf63..000000000000 --- a/test/unit/org/apache/cassandra/net/async/OutboundMessagingConnectionTest.java +++ /dev/null @@ -1,521 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.net.async; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.SSLHandshakeException; - -import com.google.common.net.InetAddresses; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelPromise; -import io.netty.channel.embedded.EmbeddedChannel; -import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; -import org.apache.cassandra.auth.IInternodeAuthenticator; -import org.apache.cassandra.config.Config; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; -import org.apache.cassandra.exceptions.ConfigurationException; -import org.apache.cassandra.locator.AbstractEndpointSnitch; -import org.apache.cassandra.locator.IEndpointSnitch; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.MessagingServiceTest; -import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult; -import org.apache.cassandra.net.async.OutboundMessagingConnection.State; - -import static org.apache.cassandra.net.MessagingService.Verb.ECHO; -import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.CLOSED; -import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.CREATING_CHANNEL; -import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.NOT_READY; -import 
static org.apache.cassandra.net.async.OutboundMessagingConnection.State.READY; - -public class OutboundMessagingConnectionTest -{ - private static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9998); - private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9999); - private static final InetAddressAndPort RECONNECT_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.3"), 9999); - private static final int MESSAGING_VERSION = MessagingService.current_version; - - private OutboundConnectionIdentifier connectionId; - private OutboundMessagingConnection omc; - private EmbeddedChannel channel; - - private IEndpointSnitch snitch; - private ServerEncryptionOptions encryptionOptions; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setup() - { - connectionId = OutboundConnectionIdentifier.small(LOCAL_ADDR, REMOTE_ADDR); - omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator()); - channel = new EmbeddedChannel(); - omc.setChannelWriter(ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty())); - - snitch = DatabaseDescriptor.getEndpointSnitch(); - encryptionOptions = DatabaseDescriptor.getInternodeMessagingEncyptionOptions(); - } - - @After - public void tearDown() - { - DatabaseDescriptor.setEndpointSnitch(snitch); - DatabaseDescriptor.setInternodeMessagingEncyptionOptions(encryptionOptions); - channel.finishAndReleaseAll(); - } - - @Test - public void sendMessage_CreatingChannel() - { - Assert.assertEquals(0, omc.backlogSize()); - omc.setState(CREATING_CHANNEL); - Assert.assertTrue(omc.sendMessage(new MessageOut<>(ECHO), 1)); - Assert.assertEquals(1, omc.backlogSize()); - Assert.assertEquals(1, 
omc.getPendingMessages().intValue()); - } - - @Test - public void sendMessage_HappyPath() - { - Assert.assertEquals(0, omc.backlogSize()); - omc.setState(READY); - Assert.assertTrue(omc.sendMessage(new MessageOut<>(ECHO), 1)); - Assert.assertEquals(0, omc.backlogSize()); - Assert.assertTrue(channel.releaseOutbound()); - } - - @Test - public void sendMessage_Closed() - { - Assert.assertEquals(0, omc.backlogSize()); - omc.setState(CLOSED); - Assert.assertFalse(omc.sendMessage(new MessageOut<>(ECHO), 1)); - Assert.assertEquals(0, omc.backlogSize()); - Assert.assertFalse(channel.releaseOutbound()); - } - - @Test - public void shouldCompressConnection_None() - { - DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.none); - Assert.assertFalse(OutboundMessagingConnection.shouldCompressConnection(LOCAL_ADDR, REMOTE_ADDR)); - } - - @Test - public void shouldCompressConnection_All() - { - DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.all); - Assert.assertTrue(OutboundMessagingConnection.shouldCompressConnection(LOCAL_ADDR, REMOTE_ADDR)); - } - - @Test - public void shouldCompressConnection_SameDc() - { - TestSnitch snitch = new TestSnitch(); - snitch.add(LOCAL_ADDR, "dc1"); - snitch.add(REMOTE_ADDR, "dc1"); - DatabaseDescriptor.setEndpointSnitch(snitch); - DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.dc); - Assert.assertFalse(OutboundMessagingConnection.shouldCompressConnection(LOCAL_ADDR, REMOTE_ADDR)); - } - - private static class TestSnitch extends AbstractEndpointSnitch - { - private Map nodeToDc = new HashMap<>(); - - void add(InetAddressAndPort node, String dc) - { - nodeToDc.put(node, dc); - } - - public String getRack(InetAddressAndPort endpoint) - { - return null; - } - - public String getDatacenter(InetAddressAndPort endpoint) - { - return nodeToDc.get(endpoint); - } - - public int compareEndpoints(InetAddressAndPort target, Replica a1, Replica a2) - { - return 0; - } - } - - @Test - public void 
shouldCompressConnection_DifferentDc() - { - TestSnitch snitch = new TestSnitch(); - snitch.add(LOCAL_ADDR, "dc1"); - snitch.add(REMOTE_ADDR, "dc2"); - DatabaseDescriptor.setEndpointSnitch(snitch); - DatabaseDescriptor.setInternodeCompression(Config.InternodeCompression.dc); - Assert.assertTrue(OutboundMessagingConnection.shouldCompressConnection(LOCAL_ADDR, REMOTE_ADDR)); - } - - @Test - public void close_softClose() - { - close(true); - } - - @Test - public void close_hardClose() - { - close(false); - } - - private void close(boolean softClose) - { - int count = 32; - for (int i = 0; i < count; i++) - omc.addToBacklog(new QueuedMessage(new MessageOut<>(ECHO), i)); - Assert.assertEquals(count, omc.backlogSize()); - Assert.assertEquals(count, omc.getPendingMessages().intValue()); - - ScheduledFuture connectionTimeoutFuture = new TestScheduledFuture(); - Assert.assertFalse(connectionTimeoutFuture.isCancelled()); - omc.setConnectionTimeoutFuture(connectionTimeoutFuture); - ChannelWriter channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - omc.setChannelWriter(channelWriter); - - omc.close(softClose); - Assert.assertFalse(channel.isActive()); - Assert.assertEquals(State.CLOSED, omc.getState()); - Assert.assertEquals(0, omc.backlogSize()); - Assert.assertEquals(0, omc.getPendingMessages().intValue()); - int sentMessages = channel.outboundMessages().size(); - - if (softClose) - Assert.assertTrue(count <= sentMessages); - else - Assert.assertEquals(0, sentMessages); - Assert.assertTrue(connectionTimeoutFuture.isCancelled()); - Assert.assertTrue(channelWriter.isClosed()); - } - - @Test - public void connect_IInternodeAuthFail() - { - IInternodeAuthenticator auth = new IInternodeAuthenticator() - { - public boolean authenticate(InetAddress remoteAddress, int remotePort) - { - return false; - } - - public void validateConfiguration() throws ConfigurationException - { - - } - }; - - MessageOut messageOut = new 
MessageOut(MessagingService.Verb.GOSSIP_DIGEST_ACK); - OutboundMessagingPool pool = new OutboundMessagingPool(REMOTE_ADDR, LOCAL_ADDR, null, - new MessagingServiceTest.MockBackPressureStrategy(null).newState(REMOTE_ADDR), auth); - omc = pool.getConnection(messageOut); - Assert.assertSame(State.NOT_READY, omc.getState()); - Assert.assertFalse(omc.connect()); - } - - @Test - public void connect_ConnectionAlreadyStarted() - { - omc.setState(State.CREATING_CHANNEL); - Assert.assertFalse(omc.connect()); - Assert.assertSame(State.CREATING_CHANNEL, omc.getState()); - } - - @Test - public void connect_ConnectionClosed() - { - omc.setState(State.CLOSED); - Assert.assertFalse(omc.connect()); - Assert.assertSame(State.CLOSED, omc.getState()); - } - - @Test - public void connectionTimeout_StateIsReady() - { - omc.setState(READY); - ChannelFuture channelFuture = channel.newPromise(); - Assert.assertFalse(omc.connectionTimeout(channelFuture)); - Assert.assertEquals(READY, omc.getState()); - } - - @Test - public void connectionTimeout_StateIsClosed() - { - omc.setState(CLOSED); - ChannelFuture channelFuture = channel.newPromise(); - Assert.assertTrue(omc.connectionTimeout(channelFuture)); - Assert.assertEquals(CLOSED, omc.getState()); - } - - @Test - public void connectionTimeout_AssumeConnectionTimedOut() - { - int count = 32; - for (int i = 0; i < count; i++) - omc.addToBacklog(new QueuedMessage(new MessageOut<>(ECHO), i)); - Assert.assertEquals(count, omc.backlogSize()); - Assert.assertEquals(count, omc.getPendingMessages().intValue()); - - omc.setState(CREATING_CHANNEL); - ChannelFuture channelFuture = channel.newPromise(); - Assert.assertTrue(omc.connectionTimeout(channelFuture)); - Assert.assertEquals(NOT_READY, omc.getState()); - Assert.assertEquals(0, omc.backlogSize()); - Assert.assertEquals(0, omc.getPendingMessages().intValue()); - } - - @Test - public void connectCallback_FutureIsSuccess() - { - ChannelPromise promise = channel.newPromise(); - promise.setSuccess(); - 
Assert.assertTrue(omc.connectCallback(promise)); - } - - @Test - public void connectCallback_Closed() - { - ChannelPromise promise = channel.newPromise(); - omc.setState(State.CLOSED); - Assert.assertFalse(omc.connectCallback(promise)); - } - - @Test - public void connectCallback_FailCauseIsSslHandshake() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new SSLHandshakeException("test is only a test")); - Assert.assertFalse(omc.connectCallback(promise)); - Assert.assertSame(State.NOT_READY, omc.getState()); - } - - @Test - public void connectCallback_FailCauseIsNPE() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new NullPointerException("test is only a test")); - Assert.assertFalse(omc.connectCallback(promise)); - Assert.assertSame(State.NOT_READY, omc.getState()); - } - - @Test - public void connectCallback_FailCauseIsIOException() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new IOException("test is only a test")); - Assert.assertFalse(omc.connectCallback(promise)); - Assert.assertSame(State.NOT_READY, omc.getState()); - } - - @Test - public void connectCallback_FailedAndItsClosed() - { - ChannelPromise promise = channel.newPromise(); - promise.setFailure(new IOException("test is only a test")); - omc.setState(CLOSED); - Assert.assertFalse(omc.connectCallback(promise)); - Assert.assertSame(State.CLOSED, omc.getState()); - } - - @Test - public void finishHandshake_GOOD() - { - ChannelWriter channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - HandshakeResult result = HandshakeResult.success(channelWriter, MESSAGING_VERSION); - ScheduledFuture connectionTimeoutFuture = new TestScheduledFuture(); - Assert.assertFalse(connectionTimeoutFuture.isCancelled()); - - omc.setChannelWriter(null); - omc.setConnectionTimeoutFuture(connectionTimeoutFuture); - omc.finishHandshake(result); - Assert.assertFalse(channelWriter.isClosed()); - 
Assert.assertEquals(channelWriter, omc.getChannelWriter()); - Assert.assertEquals(READY, omc.getState()); - Assert.assertEquals(MESSAGING_VERSION, MessagingService.instance().getVersion(REMOTE_ADDR)); - Assert.assertNull(omc.getConnectionTimeoutFuture()); - Assert.assertTrue(connectionTimeoutFuture.isCancelled()); - } - - @Test - public void finishHandshake_GOOD_ButClosed() - { - ChannelWriter channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - HandshakeResult result = HandshakeResult.success(channelWriter, MESSAGING_VERSION); - ScheduledFuture connectionTimeoutFuture = new TestScheduledFuture(); - Assert.assertFalse(connectionTimeoutFuture.isCancelled()); - - omc.setChannelWriter(null); - omc.setState(CLOSED); - omc.setConnectionTimeoutFuture(connectionTimeoutFuture); - omc.finishHandshake(result); - Assert.assertTrue(channelWriter.isClosed()); - Assert.assertNull(omc.getChannelWriter()); - Assert.assertEquals(CLOSED, omc.getState()); - Assert.assertEquals(MESSAGING_VERSION, MessagingService.instance().getVersion(REMOTE_ADDR)); - Assert.assertNull(omc.getConnectionTimeoutFuture()); - Assert.assertTrue(connectionTimeoutFuture.isCancelled()); - } - - @Test - public void finishHandshake_DISCONNECT() - { - int count = 32; - for (int i = 0; i < count; i++) - omc.addToBacklog(new QueuedMessage(new MessageOut<>(ECHO), i)); - Assert.assertEquals(count, omc.backlogSize()); - - HandshakeResult result = HandshakeResult.disconnect(MESSAGING_VERSION); - omc.finishHandshake(result); - Assert.assertNotNull(omc.getChannelWriter()); - Assert.assertEquals(CREATING_CHANNEL, omc.getState()); - Assert.assertEquals(MESSAGING_VERSION, MessagingService.instance().getVersion(REMOTE_ADDR)); - Assert.assertEquals(count, omc.backlogSize()); - } - - @Test - public void finishHandshake_CONNECT_FAILURE() - { - int count = 32; - for (int i = 0; i < count; i++) - omc.addToBacklog(new QueuedMessage(new MessageOut<>(ECHO), i)); - Assert.assertEquals(count, 
omc.backlogSize()); - - HandshakeResult result = HandshakeResult.failed(); - omc.finishHandshake(result); - Assert.assertEquals(NOT_READY, omc.getState()); - Assert.assertEquals(MESSAGING_VERSION, MessagingService.instance().getVersion(REMOTE_ADDR)); - Assert.assertEquals(0, omc.backlogSize()); - } - - @Test - public void setStateIfNotClosed_AlreadyClosed() - { - AtomicReference state = new AtomicReference<>(CLOSED); - OutboundMessagingConnection.setStateIfNotClosed(state, NOT_READY); - Assert.assertEquals(CLOSED, state.get()); - } - - @Test - public void setStateIfNotClosed_NotClosed() - { - AtomicReference state = new AtomicReference<>(READY); - OutboundMessagingConnection.setStateIfNotClosed(state, NOT_READY); - Assert.assertEquals(NOT_READY, state.get()); - } - - @Test - public void reconnectWithNewIp_HappyPath() - { - ChannelWriter channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); - omc.setChannelWriter(channelWriter); - omc.setState(READY); - OutboundConnectionIdentifier originalId = omc.getConnectionId(); - omc.reconnectWithNewIp(RECONNECT_ADDR); - Assert.assertFalse(omc.getConnectionId().equals(originalId)); - Assert.assertTrue(channelWriter.isClosed()); - Assert.assertNotSame(CLOSED, omc.getState()); - } - - @Test - public void reconnectWithNewIp_Closed() - { - omc.setState(CLOSED); - OutboundConnectionIdentifier originalId = omc.getConnectionId(); - omc.reconnectWithNewIp(RECONNECT_ADDR); - Assert.assertSame(omc.getConnectionId(), originalId); - Assert.assertSame(CLOSED, omc.getState()); - } - - @Test - public void reconnectWithNewIp_UnsedConnection() - { - omc.setState(NOT_READY); - OutboundConnectionIdentifier originalId = omc.getConnectionId(); - omc.reconnectWithNewIp(RECONNECT_ADDR); - Assert.assertNotSame(omc.getConnectionId(), originalId); - Assert.assertSame(NOT_READY, omc.getState()); - } - - @Test - public void maybeUpdateConnectionId_NoEncryption() - { - OutboundConnectionIdentifier connectionId = 
omc.getConnectionId(); - int version = omc.getTargetVersion(); - omc.maybeUpdateConnectionId(); - Assert.assertEquals(connectionId, omc.getConnectionId()); - Assert.assertEquals(version, omc.getTargetVersion()); - } - - @Test - public void maybeUpdateConnectionId_SameVersion() - { - ServerEncryptionOptions encryptionOptions = new ServerEncryptionOptions(); - omc = new OutboundMessagingConnection(connectionId, encryptionOptions, Optional.empty(), new AllowAllInternodeAuthenticator()); - OutboundConnectionIdentifier connectionId = omc.getConnectionId(); - int version = omc.getTargetVersion(); - omc.maybeUpdateConnectionId(); - Assert.assertEquals(connectionId, omc.getConnectionId()); - Assert.assertEquals(version, omc.getTargetVersion()); - } - - @Test - public void maybeUpdateConnectionId_3_X_Version() - { - ServerEncryptionOptions encryptionOptions = new ServerEncryptionOptions(); - encryptionOptions.enabled = true; - encryptionOptions.internode_encryption = ServerEncryptionOptions.InternodeEncryption.all; - DatabaseDescriptor.setInternodeMessagingEncyptionOptions(encryptionOptions); - omc = new OutboundMessagingConnection(connectionId, encryptionOptions, Optional.empty(), new AllowAllInternodeAuthenticator()); - int peerVersion = MessagingService.VERSION_30; - MessagingService.instance().setVersion(connectionId.remote(), MessagingService.VERSION_30); - - OutboundConnectionIdentifier connectionId = omc.getConnectionId(); - omc.maybeUpdateConnectionId(); - Assert.assertNotEquals(connectionId, omc.getConnectionId()); - Assert.assertEquals(InetAddressAndPort.getByAddressOverrideDefaults(REMOTE_ADDR.address, DatabaseDescriptor.getSSLStoragePort()), omc.getConnectionId().remote()); - Assert.assertEquals(InetAddressAndPort.getByAddressOverrideDefaults(REMOTE_ADDR.address, DatabaseDescriptor.getSSLStoragePort()), omc.getConnectionId().connectionAddress()); - Assert.assertEquals(peerVersion, omc.getTargetVersion()); - } -} diff --git 
a/test/unit/org/apache/cassandra/net/async/OutboundMessagingPoolTest.java b/test/unit/org/apache/cassandra/net/async/OutboundMessagingPoolTest.java deleted file mode 100644 index ecd8697ff339..000000000000 --- a/test/unit/org/apache/cassandra/net/async/OutboundMessagingPoolTest.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.net.async; - -import java.util.ArrayList; -import java.util.List; - -import com.google.common.net.InetAddresses; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; - -import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.db.WriteResponse; -import org.apache.cassandra.gms.GossipDigestSyn; -import org.apache.cassandra.io.IVersionedSerializer; -import org.apache.cassandra.io.util.DataInputPlus; -import org.apache.cassandra.io.util.DataOutputPlus; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.BackPressureState; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; - -public class OutboundMessagingPoolTest -{ - private static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9476); - private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9476); - private static final InetAddressAndPort RECONNECT_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.3"), 9476); - private static final List INTERNODE_MESSAGING_CONN_TYPES = new ArrayList() - {{ add(ConnectionType.GOSSIP); add(ConnectionType.LARGE_MESSAGE); add(ConnectionType.SMALL_MESSAGE); }}; - - private OutboundMessagingPool pool; - - @BeforeClass - public static void before() - { - DatabaseDescriptor.daemonInitialization(); - } - - @Before - public void setup() - { - BackPressureState backPressureState = DatabaseDescriptor.getBackPressureStrategy().newState(REMOTE_ADDR); - pool = new OutboundMessagingPool(REMOTE_ADDR, LOCAL_ADDR, null, backPressureState, 
new AllowAllInternodeAuthenticator()); - } - - @After - public void tearDown() - { - if (pool != null) - pool.close(false); - } - - @Test - public void getConnection_Gossip() - { - GossipDigestSyn syn = new GossipDigestSyn("cluster", "partitioner", new ArrayList<>(0)); - MessageOut message = new MessageOut<>(MessagingService.Verb.GOSSIP_DIGEST_SYN, - syn, GossipDigestSyn.serializer); - Assert.assertEquals(ConnectionType.GOSSIP, pool.getConnection(message).getConnectionId().type()); - } - - @Test - public void getConnection_SmallMessage() - { - MessageOut message = WriteResponse.createMessage(); - Assert.assertEquals(ConnectionType.SMALL_MESSAGE, pool.getConnection(message).getConnectionId().type()); - } - - @Test - public void getConnection_LargeMessage() - { - // just need a serializer to report a size, as fake as it may be - IVersionedSerializer serializer = new IVersionedSerializer() - { - public void serialize(Object o, DataOutputPlus out, int version) - { - - } - - public Object deserialize(DataInputPlus in, int version) - { - return null; - } - - public long serializedSize(Object o, int version) - { - return OutboundMessagingPool.LARGE_MESSAGE_THRESHOLD + 1; - } - }; - MessageOut message = new MessageOut<>(MessagingService.Verb.UNUSED_5, "payload", serializer); - Assert.assertEquals(ConnectionType.LARGE_MESSAGE, pool.getConnection(message).getConnectionId().type()); - } - - @Test - public void close() - { - for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) - Assert.assertNotSame(OutboundMessagingConnection.State.CLOSED, pool.getConnection(type).getState()); - pool.close(false); - for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) - Assert.assertEquals(OutboundMessagingConnection.State.CLOSED, pool.getConnection(type).getState()); - } - - @Test - public void reconnectWithNewIp() - { - for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) - { - Assert.assertEquals(REMOTE_ADDR, pool.getPreferredRemoteAddr()); - 
Assert.assertEquals(REMOTE_ADDR, pool.getConnection(type).getConnectionId().connectionAddress()); - } - - pool.reconnectWithNewIp(RECONNECT_ADDR); - - for (ConnectionType type : INTERNODE_MESSAGING_CONN_TYPES) - { - Assert.assertEquals(RECONNECT_ADDR, pool.getPreferredRemoteAddr()); - Assert.assertEquals(RECONNECT_ADDR, pool.getConnection(type).getConnectionId().connectionAddress()); - } - } - - @Test - public void timeoutCounter() - { - long originalValue = pool.getTimeouts(); - pool.incrementTimeout(); - Assert.assertEquals(originalValue + 1, pool.getTimeouts()); - } -} diff --git a/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java new file mode 100644 index 000000000000..7e3b004041e1 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.proxy; + +import java.util.ArrayDeque; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.ScheduledFuture; + +public class InboundProxyHandler extends ChannelInboundHandlerAdapter +{ + private final ArrayDeque forwardQueue; + private ScheduledFuture scheduled = null; + private final Controller controller; + public InboundProxyHandler(Controller controller) + { + this.controller = controller; + this.forwardQueue = new ArrayDeque<>(1024); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception + { + super.channelActive(ctx); + ctx.read(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + controller.onDisconnect.run(); + + if (scheduled != null) + { + scheduled.cancel(true); + scheduled = null; + } + + if (!forwardQueue.isEmpty()) + forwardQueue.clear(); + + super.channelInactive(ctx); + } + + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + { + Forward forward = controller.forwardStrategy.forward(ctx, msg); + forwardQueue.offer(forward); + maybeScheduleNext(ctx.channel().eventLoop()); + controller.onRead.run(); + ctx.channel().read(); + } + + private void maybeScheduleNext(EventExecutor executor) + { + if (forwardQueue.isEmpty()) + { + // Ran out of items to process + scheduled = null; + } + else if (scheduled == null) + { + // Schedule next available or let the last in line schedule it + Forward forward = forwardQueue.poll(); + scheduled = forward.schedule(executor); + scheduled.addListener((e) -> { + scheduled = null; + maybeScheduleNext(executor); + }); + } + } + + private static class Forward + { + final long arrivedAt; + final long latency; + final Runnable handler; + + private Forward(long arrivedAt, long 
latency, Runnable handler) + { + this.arrivedAt = arrivedAt; + this.latency = latency; + this.handler = handler; + } + + ScheduledFuture schedule(EventExecutor executor) + { + long now = System.currentTimeMillis(); + long elapsed = now - arrivedAt; + long runIn = latency - elapsed; + + if (runIn > 0) + return executor.schedule(handler, runIn, TimeUnit.MILLISECONDS); + else + return executor.schedule(handler, 0, TimeUnit.MILLISECONDS); + } + } + + private static class ForwardNormally implements ForwardStrategy + { + static ForwardNormally instance = new ForwardNormally(); + + public Forward forward(ChannelHandlerContext ctx, Object msg) + { + return new Forward(System.currentTimeMillis(), + 0, + () -> ctx.fireChannelRead(msg)); + } + } + + public interface ForwardStrategy + { + public Forward forward(ChannelHandlerContext ctx, Object msg); + } + + private static class ForwardWithLatency implements ForwardStrategy + { + private final long latency; + private final TimeUnit timeUnit; + + ForwardWithLatency(long latency, TimeUnit timeUnit) + { + this.latency = latency; + this.timeUnit = timeUnit; + } + + public Forward forward(ChannelHandlerContext ctx, Object msg) + { + return new Forward(System.currentTimeMillis(), + timeUnit.toMillis(latency), + () -> ctx.fireChannelRead(msg)); + } + } + + private static class CloseAfterRead implements ForwardStrategy + { + private final Runnable afterClose; + + CloseAfterRead(Runnable afterClose) + { + this.afterClose = afterClose; + } + + public Forward forward(ChannelHandlerContext ctx, Object msg) + { + return new Forward(System.currentTimeMillis(), + 0, + () -> { + ctx.channel().close().syncUninterruptibly(); + afterClose.run(); + }); + } + } + + private static class TransformPayload implements ForwardStrategy + { + private final Function fn; + + TransformPayload(Function fn) + { + this.fn = fn; + } + + public Forward forward(ChannelHandlerContext ctx, Object msg) + { + return new Forward(System.currentTimeMillis(), + 0, + () -> 
ctx.fireChannelRead(fn.apply((T) msg))); + } + } + + public static class Controller + { + private volatile InboundProxyHandler.ForwardStrategy forwardStrategy; + private volatile Runnable onRead = () -> {}; + private volatile Runnable onDisconnect = () -> {}; + + public Controller() + { + this.forwardStrategy = ForwardNormally.instance; + } + public void onRead(Runnable onRead) + { + this.onRead = onRead; + } + + public void onDisconnect(Runnable onDisconnect) + { + this.onDisconnect = onDisconnect; + } + + public void reset() + { + this.forwardStrategy = ForwardNormally.instance; + } + + public void withLatency(long latency, TimeUnit timeUnit) + { + this.forwardStrategy = new ForwardWithLatency(latency, timeUnit); + } + + public void withCloseAfterRead(Runnable afterClose) + { + this.forwardStrategy = new CloseAfterRead(afterClose); + } + + public void withPayloadTransform(Function fn) + { + this.forwardStrategy = new TransformPayload<>(fn); + } + } + +} diff --git a/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java new file mode 100644 index 000000000000..d070f5630fff --- /dev/null +++ b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.proxy; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import org.junit.Assert; +import org.junit.Test; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; + +public class ProxyHandlerTest +{ + private final Object PAYLOAD = new Object(); + + @Test + public void testLatency() throws Throwable + { + test((proxyHandler, testHandler, channel) -> { + int count = 1; + long latency = 100; + CountDownLatch latch = new CountDownLatch(count); + long start = System.currentTimeMillis(); + testHandler.onRead = new Consumer() + { + int last = -1; + public void accept(Object o) + { + // Make sure that order is preserved + Assert.assertEquals(last + 1, o); + last = (int) o; + + long elapsed = System.currentTimeMillis() - start; + Assert.assertTrue("Latency was:" + elapsed, elapsed > latency); + latch.countDown(); + } + }; + + proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS); + + for (int i = 0; i < count; i++) + { + ByteBuf bb = Unpooled.buffer(Integer.BYTES); + bb.writeInt(i); + channel.writeAndFlush(i); + } + + 
Assert.assertTrue(latch.await(10, TimeUnit.SECONDS)); + }); + } + + @Test + public void testNormalDelivery() throws Throwable + { + test((proxyHandler, testHandler, channelPipeline) -> { + int count = 10; + CountDownLatch latch = new CountDownLatch(count); + AtomicLong end = new AtomicLong(); + testHandler.onRead = (o) -> { + end.set(System.currentTimeMillis()); + latch.countDown(); + }; + + for (int i = 0; i < count; i++) + channelPipeline.writeAndFlush(PAYLOAD); + Assert.assertTrue(latch.await(10, TimeUnit.SECONDS)); + + }); + } + + @Test + public void testLatencyForMany() throws Throwable + { + class Event { + private final long latency; + private final long start; + private final int idx; + + Event(long latency, int idx) + { + this.latency = latency; + this.start = System.currentTimeMillis(); + this.idx = idx; + } + } + + test((proxyHandler, testHandler, channel) -> { + int count = 150; + CountDownLatch latch = new CountDownLatch(count); + AtomicInteger counter = new AtomicInteger(); + testHandler.onRead = new Consumer() + { + int lastSeen = -1; + public void accept(Object o) + { + Event e = (Event) o; + Assert.assertEquals(lastSeen + 1, e.idx); + lastSeen = e.idx; + long elapsed = System.currentTimeMillis() - e.start; + Assert.assertTrue(elapsed >= e.latency); + counter.incrementAndGet(); + latch.countDown(); + } + }; + + int idx = 0; + for (int i = 0; i < count / 3; i++) + { + for (long latency : new long[]{ 100, 200, 0 }) + { + proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS); + CountDownLatch read = new CountDownLatch(1); + proxyHandler.onRead(read::countDown); + channel.writeAndFlush(new Event(latency, idx++)); + Assert.assertTrue(read.await(10, TimeUnit.SECONDS)); + } + } + + Assert.assertTrue(latch.await(10, TimeUnit.SECONDS)); + Assert.assertEquals(counter.get(), count); + }); + } + + private interface DoTest + { + public void doTest(InboundProxyHandler.Controller proxy, TestHandler testHandler, Channel channel) throws Throwable; + } + + + 
public void test(DoTest test) throws Throwable + { + EventLoopGroup serverGroup = new NioEventLoopGroup(1); + EventLoopGroup clientGroup = new NioEventLoopGroup(1); + + InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller(); + InboundProxyHandler proxyHandler = new InboundProxyHandler(controller); + TestHandler testHandler = new TestHandler(); + + ServerBootstrap sb = new ServerBootstrap(); + sb.group(serverGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) + { + ch.pipeline() + .addLast(proxyHandler) + .addLast(testHandler); + } + }) + .childOption(ChannelOption.AUTO_READ, false); + + Bootstrap cb = new Bootstrap(); + cb.group(clientGroup) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline() + .addLast(new LoggingHandler(LogLevel.TRACE)); + } + }); + + final LocalAddress addr = new LocalAddress("test"); + + Channel serverChannel = sb.bind(addr).sync().channel(); + + Channel clientChannel = cb.connect(addr).sync().channel(); + test.doTest(controller, testHandler, clientChannel); + + clientChannel.close(); + serverChannel.close(); + serverGroup.shutdownGracefully(); + clientGroup.shutdownGracefully(); + } + + + public static class TestHandler extends ChannelInboundHandlerAdapter + { + private Consumer onRead = (o) -> {}; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + { + onRead.accept(msg); + } + } +} diff --git a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java index 903a273178dd..443d59ecaca2 100644 --- a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java +++ b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java @@ -137,14 +137,12 @@ public void testDifference() throws Throwable LocalSyncTask task = new 
LocalSyncTask(desc, r1.endpoint, r2.endpoint, MerkleTrees.difference(r1.trees, r2.trees), NO_PENDING_REPAIR, true, true, PreviewKind.NONE); DefaultConnectionFactory.MAX_CONNECT_ATTEMPTS = 1; - DefaultConnectionFactory.MAX_WAIT_TIME_NANOS = TimeUnit.SECONDS.toNanos(2); try { task.run(); } finally { - DefaultConnectionFactory.MAX_WAIT_TIME_NANOS = TimeUnit.SECONDS.toNanos(30); DefaultConnectionFactory.MAX_CONNECT_ATTEMPTS = 3; } @@ -240,10 +238,6 @@ private MerkleTrees createInitialTree(RepairJobDesc desc, IPartitioner partition MerkleTrees tree = new MerkleTrees(partitioner); tree.addMerkleTrees((int) Math.pow(2, 15), desc.ranges); tree.init(); - for (MerkleTree.TreeRange r : tree.invalids()) - { - r.ensureHashInitialised(); - } return tree; } diff --git a/test/unit/org/apache/cassandra/repair/RepairJobTest.java b/test/unit/org/apache/cassandra/repair/RepairJobTest.java index 6db29dc720ca..068544d897bf 100644 --- a/test/unit/org/apache/cassandra/repair/RepairJobTest.java +++ b/test/unit/org/apache/cassandra/repair/RepairJobTest.java @@ -52,10 +52,9 @@ import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.IMessageSink; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.repair.messages.RepairMessage; import org.apache.cassandra.repair.messages.SyncRequest; import org.apache.cassandra.schema.KeyspaceParams; @@ -156,7 +155,8 @@ public void setup() public void reset() { ActiveRepairService.instance.terminateSessions(); - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().outboundSink.clear(); + MessagingService.instance().inboundSink.clear(); FBUtilities.reset(); } @@ -167,11 +167,11 @@ public void reset() public void 
testEndToEndNoDifferences() throws InterruptedException, ExecutionException, TimeoutException { Map mockTrees = new HashMap<>(); - mockTrees.put(FBUtilities.getBroadcastAddressAndPort(), createInitialTree(false)); + mockTrees.put(addr1, createInitialTree(false)); mockTrees.put(addr2, createInitialTree(false)); mockTrees.put(addr3, createInitialTree(false)); - List observedMessages = new ArrayList<>(); + List> observedMessages = new ArrayList<>(); interceptRepairMessages(mockTrees, observedMessages); job.run(); @@ -182,14 +182,14 @@ public void testEndToEndNoDifferences() throws InterruptedException, ExecutionEx assertEquals(0, result.stats.size()); // RepairJob should send out SNAPSHOTS -> VALIDATIONS -> done - List expectedTypes = new ArrayList<>(); + List expectedTypes = new ArrayList<>(); for (int i = 0; i < 3; i++) - expectedTypes.add(RepairMessage.Type.SNAPSHOT); + expectedTypes.add(Verb.SNAPSHOT_MSG); for (int i = 0; i < 3; i++) - expectedTypes.add(RepairMessage.Type.VALIDATION_REQUEST); + expectedTypes.add(Verb.VALIDATION_REQ); assertEquals(expectedTypes, observedMessages.stream() - .map(k -> ((RepairMessage) k.payload).messageType) + .map(Message::verb) .collect(Collectors.toList())); } @@ -208,7 +208,7 @@ public void testNoTreesRetainedAfterDifference() throws Throwable List mockTreeResponses = mockTrees.entrySet().stream() .map(e -> new TreeResponse(e.getKey(), e.getValue())) .collect(Collectors.toList()); - List messages = new ArrayList<>(); + List> messages = new ArrayList<>(); interceptRepairMessages(mockTrees, messages); long singleTreeSize = ObjectSizes.measureDeep(mockTrees.get(addr1)); @@ -252,7 +252,7 @@ public void testNoTreesRetainedAfterDifference() throws Throwable assertTrue(results.stream().allMatch(s -> s.numberOfDifferences == 1)); assertEquals(2, messages.size()); - assertTrue(messages.stream().allMatch(m -> ((RepairMessage) m.payload).messageType == RepairMessage.Type.SYNC_REQUEST)); + assertTrue(messages.stream().allMatch(m -> m.verb() 
== Verb.SYNC_REQ)); } @Test @@ -775,10 +775,6 @@ private MerkleTrees createInitialTree(boolean invalidate) MerkleTrees tree = new MerkleTrees(MURMUR3_PARTITIONER); tree.addMerkleTrees((int) Math.pow(2, 15), fullRange); tree.init(); - for (MerkleTree.TreeRange r : tree.invalids()) - { - r.ensureHashInitialised(); - } if (invalidate) { @@ -792,49 +788,36 @@ private MerkleTrees createInitialTree(boolean invalidate) } private void interceptRepairMessages(Map mockTrees, - List messageCapture) + List> messageCapture) { - MessagingService.instance().addMessageSink(new IMessageSink() - { - public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to) - { - if (message == null || !(message.payload instanceof RepairMessage)) - return false; - - // So different Thread's messages don't overwrite each other. - synchronized (messageLock) - { - messageCapture.add(message); - } - - RepairMessage rm = (RepairMessage) message.payload; - switch (rm.messageType) - { - case SNAPSHOT: - MessageIn messageIn = MessageIn.create(to, null, - Collections.emptyMap(), - MessagingService.Verb.REQUEST_RESPONSE, - MessagingService.current_version); - MessagingService.instance().receive(messageIn, id); - break; - case VALIDATION_REQUEST: - session.validationComplete(sessionJobDesc, to, mockTrees.get(to)); - break; - case SYNC_REQUEST: - SyncRequest syncRequest = (SyncRequest) rm; - session.syncComplete(sessionJobDesc, new SyncNodePair(syncRequest.src, syncRequest.dst), - true, Collections.emptyList()); - break; - default: - break; - } + MessagingService.instance().inboundSink.add(message -> message.verb().isResponse()); + MessagingService.instance().outboundSink.add((message, to) -> { + if (message == null || !(message.payload instanceof RepairMessage)) return false; + + // So different Thread's messages don't overwrite each other. 
+ synchronized (messageLock) + { + messageCapture.add(message); } - public boolean allowIncomingMessage(MessageIn message, int id) + switch (message.verb()) { - return message.verb == MessagingService.Verb.REQUEST_RESPONSE; + case SNAPSHOT_MSG: + MessagingService.instance().callbacks.removeAndRespond(message.id(), to, message.emptyResponse()); + break; + case VALIDATION_REQ: + session.validationComplete(sessionJobDesc, to, mockTrees.get(to)); + break; + case SYNC_REQ: + SyncRequest syncRequest = (SyncRequest) message.payload; + session.syncComplete(sessionJobDesc, new SyncNodePair(syncRequest.src, syncRequest.dst), + true, Collections.emptyList()); + break; + default: + break; } + return false; }); } } diff --git a/test/unit/org/apache/cassandra/repair/SymmetricRemoteSyncTaskTest.java b/test/unit/org/apache/cassandra/repair/SymmetricRemoteSyncTaskTest.java index 7f48788bc6fa..cba64ae17f5f 100644 --- a/test/unit/org/apache/cassandra/repair/SymmetricRemoteSyncTaskTest.java +++ b/test/unit/org/apache/cassandra/repair/SymmetricRemoteSyncTaskTest.java @@ -48,7 +48,7 @@ public InstrumentedSymmetricRemoteSyncTask(InetAddressAndPort e1, InetAddressAnd InetAddressAndPort sentTo = null; @Override - void sendRequest(RepairMessage request, InetAddressAndPort to) + void sendRequest(SyncRequest request, InetAddressAndPort to) { Assert.assertNull(sentMessage); Assert.assertNotNull(request); diff --git a/test/unit/org/apache/cassandra/repair/ValidatorTest.java b/test/unit/org/apache/cassandra/repair/ValidatorTest.java index ff6b11c94d49..20f50ed3fce6 100644 --- a/test/unit/org/apache/cassandra/repair/ValidatorTest.java +++ b/test/unit/org/apache/cassandra/repair/ValidatorTest.java @@ -40,6 +40,7 @@ import org.apache.cassandra.SchemaLoader; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.db.BufferDecoratedKey; import org.apache.cassandra.db.ColumnFamilyStore; @@ 
-48,19 +49,15 @@ import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; -import org.apache.cassandra.net.IMessageSink; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.repair.messages.RepairMessage; -import org.apache.cassandra.repair.messages.ValidationComplete; +import org.apache.cassandra.repair.messages.ValidationResponse; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.service.ActiveRepairService; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.MerkleTree; import org.apache.cassandra.utils.MerkleTrees; -import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.UUIDGen; import static org.junit.Assert.assertEquals; @@ -92,7 +89,7 @@ public static void defineSchema() throws Exception @After public void tearDown() { - MessagingService.instance().clearMessageSinks(); + MessagingService.instance().outboundSink.clear(); DatabaseDescriptor.setRepairSessionSpaceInMegabytes(testSizeMegabytes); } @@ -108,7 +105,7 @@ public void testValidatorComplete() throws Throwable Range range = new Range<>(partitioner.getMinimumToken(), partitioner.getRandomToken()); final RepairJobDesc desc = new RepairJobDesc(UUID.randomUUID(), UUID.randomUUID(), keyspace, columnFamily, Arrays.asList(range)); - final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); + final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); InetAddressAndPort remote = InetAddressAndPort.getByName("127.0.0.2"); @@ -131,13 +128,12 @@ public void testValidatorComplete() throws Throwable Token min = tree.partitioner().getMinimumToken(); assertNotNull(tree.hash(new Range<>(min, min))); - MessageOut message = 
outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); - assertEquals(MessagingService.Verb.REPAIR_MESSAGE, message.verb); - RepairMessage m = (RepairMessage) message.payload; - assertEquals(RepairMessage.Type.VALIDATION_COMPLETE, m.messageType); + Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); + assertEquals(Verb.VALIDATION_RSP, message.verb()); + ValidationResponse m = (ValidationResponse) message.payload; assertEquals(desc, m.desc); - assertTrue(((ValidationComplete) m).success()); - assertNotNull(((ValidationComplete) m).trees); + assertTrue(m.success()); + assertNotNull(m.trees); } @@ -147,20 +143,19 @@ public void testValidatorFailed() throws Throwable Range range = new Range<>(partitioner.getMinimumToken(), partitioner.getRandomToken()); final RepairJobDesc desc = new RepairJobDesc(UUID.randomUUID(), UUID.randomUUID(), keyspace, columnFamily, Arrays.asList(range)); - final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); + final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); InetAddressAndPort remote = InetAddressAndPort.getByName("127.0.0.2"); Validator validator = new Validator(desc, remote, 0, PreviewKind.NONE); validator.fail(); - MessageOut message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); - assertEquals(MessagingService.Verb.REPAIR_MESSAGE, message.verb); - RepairMessage m = (RepairMessage) message.payload; - assertEquals(RepairMessage.Type.VALIDATION_COMPLETE, m.messageType); + Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); + assertEquals(Verb.VALIDATION_RSP, message.verb()); + ValidationResponse m = (ValidationResponse) message.payload; assertEquals(desc, m.desc); - assertFalse(((ValidationComplete) m).success()); - assertNull(((ValidationComplete) m).trees); + assertFalse(m.success()); + assertNull(m.trees); } @Test @@ -192,7 +187,7 @@ public void simpleValidationTest(int n) throws Exception CompactionsTest.populate(keyspace, 
columnFamily, 0, n, 0); //ttl=3s - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); // wait enough to force single compaction @@ -204,28 +199,28 @@ public void simpleValidationTest(int n) throws Exception cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(), sstable.last.getToken()))); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); - final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); ValidationManager.instance.submitValidation(cfs, validator); - MessageOut message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); - assertEquals(MessagingService.Verb.REPAIR_MESSAGE, message.verb); - RepairMessage m = (RepairMessage) message.payload; - assertEquals(RepairMessage.Type.VALIDATION_COMPLETE, m.messageType); + Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); + assertEquals(Verb.VALIDATION_RSP, message.verb()); + ValidationResponse m = (ValidationResponse) message.payload; assertEquals(desc, m.desc); - assertTrue(((ValidationComplete) m).success()); - MerkleTrees trees = ((ValidationComplete) m).trees; + assertTrue(m.success()); - Iterator, MerkleTree>> iterator = trees.iterator(); + Iterator, MerkleTree>> iterator = m.trees.iterator(); while (iterator.hasNext()) { assertEquals(Math.pow(2, 
Math.ceil(Math.log(n) / Math.log(2))), iterator.next().getValue().size(), 0.0); } - assertEquals(trees.rowCount(), n); + assertEquals(m.trees.rowCount(), n); } /* @@ -249,7 +244,7 @@ public void testSizeLimiting() throws Exception // 2 ** 14 rows would normally use 2^14 leaves, but with only 1 meg we should only use 2^12 CompactionsTest.populate(keyspace, columnFamily, 0, 1 << 14, 0); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); // wait enough to force single compaction @@ -261,16 +256,18 @@ public void testSizeLimiting() throws Exception cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(), sstable.last.getToken()))); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); - final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); ValidationManager.instance.submitValidation(cfs, validator); - MessageOut message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); - MerkleTrees trees = ((ValidationComplete) message.payload).trees; + Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); + MerkleTrees trees = ((ValidationResponse) message.payload).trees; Iterator, MerkleTree>> iterator = trees.iterator(); int numTrees = 0; @@ -306,7 +303,7 @@ public void testRangeSplittingTreeSizeLimit() throws Exception // 2 ** 14 
rows would normally use 2^14 leaves, but with only 1 meg we should only use 2^12 CompactionsTest.populate(keyspace, columnFamily, 0, 1 << 14, 0); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertEquals(1, cfs.getLiveSSTables().size()); // wait enough to force single compaction @@ -321,16 +318,18 @@ public void testRangeSplittingTreeSizeLimit() throws Exception final RepairJobDesc desc = new RepairJobDesc(repairSessionId, UUIDGen.getTimeUUID(), cfs.keyspace.getName(), cfs.getTableName(), ranges); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); - final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + final CompletableFuture outgoingMessageSink = registerOutgoingMessageSink(); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); ValidationManager.instance.submitValidation(cfs, validator); - MessageOut message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); - MerkleTrees trees = ((ValidationComplete) message.payload).trees; + Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); + MerkleTrees trees = ((ValidationResponse) message.payload).trees; // Should have 4 trees each with a depth of on average 10 (since each range should have gotten 0.25 megabytes) Iterator, MerkleTree>> iterator = trees.iterator(); @@ -410,22 +409,10 @@ public void testCountingHasher() assertEquals(len, ((Validator.CountingHasher)hashers[0]).getCount()); } - private CompletableFuture registerOutgoingMessageSink() + private 
CompletableFuture registerOutgoingMessageSink() { - final CompletableFuture future = new CompletableFuture<>(); - MessagingService.instance().addMessageSink(new IMessageSink() - { - public boolean allowOutgoingMessage(MessageOut message, int id, InetAddressAndPort to) - { - future.complete(message); - return false; - } - - public boolean allowIncomingMessage(MessageIn message, int id) - { - return false; - } - }); + final CompletableFuture future = new CompletableFuture<>(); + MessagingService.instance().outboundSink.add((message, to) -> future.complete(message)); return future; } } diff --git a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java index 96930108d634..88810180eb96 100644 --- a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java +++ b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java @@ -31,6 +31,7 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.repair.TreeResponse; +import org.apache.cassandra.utils.HashingUtils; import org.apache.cassandra.utils.MerkleTree; import org.apache.cassandra.utils.MerkleTrees; import org.apache.cassandra.utils.MerkleTreesTest; @@ -40,6 +41,11 @@ public class DifferenceHolderTest { + private static byte[] digest(String string) + { + return HashingUtils.newMessageDigest("SHA-256").digest(string.getBytes()); + } + @Test public void testFromEmptyMerkleTrees() throws UnknownHostException { @@ -74,9 +80,9 @@ public void testFromMismatchedMerkleTrees() throws UnknownHostException mt1.init(); mt2.init(); // add dummy hashes to both trees - for (MerkleTree.TreeRange range : mt1.invalids()) + for (MerkleTree.TreeRange range : mt1.rangeIterator()) range.addAll(new MerkleTreesTest.HIterator(range.right)); - for (MerkleTree.TreeRange range : mt2.invalids()) + for (MerkleTree.TreeRange range : mt2.rangeIterator()) 
range.addAll(new MerkleTreesTest.HIterator(range.right)); MerkleTree.TreeRange leftmost = null; @@ -85,14 +91,14 @@ public void testFromMismatchedMerkleTrees() throws UnknownHostException mt1.maxsize(fullRange, maxsize + 2); // give some room for splitting // split the leftmost - Iterator ranges = mt1.invalids(); + Iterator ranges = mt1.rangeIterator(); leftmost = ranges.next(); mt1.split(leftmost.right); // set the hashes for the leaf of the created split middle = mt1.get(leftmost.right); - middle.hash("arbitrary!".getBytes()); - mt1.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash("even more arbitrary!".getBytes()); + middle.hash(digest("arbitrary!")); + mt1.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash(digest("even more arbitrary!")); TreeResponse tr1 = new TreeResponse(a1, mt1); TreeResponse tr2 = new TreeResponse(a2, mt2); diff --git a/test/unit/org/apache/cassandra/repair/consistent/CoordinatorMessagingTest.java b/test/unit/org/apache/cassandra/repair/consistent/CoordinatorMessagingTest.java index b532abdcd573..c9fd9132c832 100644 --- a/test/unit/org/apache/cassandra/repair/consistent/CoordinatorMessagingTest.java +++ b/test/unit/org/apache/cassandra/repair/consistent/CoordinatorMessagingTest.java @@ -42,10 +42,11 @@ import org.apache.cassandra.cql3.statements.schema.CreateTableStatement; import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.net.MockMessagingService; import org.apache.cassandra.net.MockMessagingSpy; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.repair.AbstractRepairTest; import org.apache.cassandra.repair.RepairSessionResult; import org.apache.cassandra.repair.messages.FailSession; @@ -60,7 +61,6 @@ import org.apache.cassandra.service.ActiveRepairService; import static 
org.apache.cassandra.net.MockMessagingService.all; -import static org.apache.cassandra.net.MockMessagingService.payload; import static org.apache.cassandra.net.MockMessagingService.to; import static org.apache.cassandra.net.MockMessagingService.verb; import static org.junit.Assert.fail; @@ -256,7 +256,7 @@ public void testMockedMessagingPrepareTimeout() throws InterruptedException, Exe // expected } // we won't send out any fail session message in case of timeouts - spyPrepare.expectMockedMessageIn(2).get(100, TimeUnit.MILLISECONDS); + spyPrepare.expectMockedMessage(2).get(100, TimeUnit.MILLISECONDS); sendFailSessionUnexpectedSpy.interceptNoMsg(100, TimeUnit.MILLISECONDS); Assert.assertFalse(repairSubmitted.get()); Assert.assertFalse(hasFailures.get()); @@ -274,52 +274,36 @@ private MockMessagingSpy createPrepareSpy(Collection failed, Collection timeout, Function sessionIdFunc) { - return MockMessagingService.when( - all(verb(MessagingService.Verb.REPAIR_MESSAGE), - payload((p) -> p instanceof PrepareConsistentRequest)) - ).respond((msgOut, to) -> - { - if(timeout.contains(to)) return null; - else return MessageIn.create(to, - new PrepareConsistentResponse(sessionIdFunc.apply((PrepareConsistentRequest) msgOut.payload), to, !failed.contains(to)), - Collections.emptyMap(), - MessagingService.Verb.REPAIR_MESSAGE, - MessagingService.current_version); - }); + return MockMessagingService.when(verb(Verb.PREPARE_CONSISTENT_REQ)).respond((msgOut, to) -> + { + if (timeout.contains(to)) + return null; + + return Message.out(Verb.PREPARE_CONSISTENT_RSP, + new PrepareConsistentResponse(sessionIdFunc.apply((PrepareConsistentRequest) msgOut.payload), to, !failed.contains(to))); + }); } private MockMessagingSpy createFinalizeSpy(Collection failed, Collection timeout) { - return MockMessagingService.when( - all(verb(MessagingService.Verb.REPAIR_MESSAGE), - payload((p) -> p instanceof FinalizePropose)) - ).respond((msgOut, to) -> - { - if(timeout.contains(to)) return null; - else 
return MessageIn.create(to, - new FinalizePromise(((FinalizePropose) msgOut.payload).sessionID, to, !failed.contains(to)), - Collections.emptyMap(), - MessagingService.Verb.REPAIR_MESSAGE, - MessagingService.current_version); - }); + return MockMessagingService.when(verb(Verb.FINALIZE_PROPOSE_MSG)).respond((msgOut, to) -> + { + if (timeout.contains(to)) + return null; + + return Message.out(Verb.FINALIZE_PROMISE_MSG, new FinalizePromise(((FinalizePropose) msgOut.payload).sessionID, to, !failed.contains(to))); + }); } private MockMessagingSpy createCommitSpy() { - return MockMessagingService.when( - all(verb(MessagingService.Verb.REPAIR_MESSAGE), - payload((p) -> p instanceof FinalizeCommit)) - ).dontReply(); + return MockMessagingService.when(verb(Verb.FINALIZE_COMMIT_MSG)).dontReply(); } private MockMessagingSpy createFailSessionSpy(Collection participants) { - return MockMessagingService.when( - all(verb(MessagingService.Verb.REPAIR_MESSAGE), - payload((p) -> p instanceof FailSession), - to(participants::contains)) - ).dontReply(); + return MockMessagingService.when(all(verb(Verb.FAILED_SESSION_MSG), to(participants::contains))).dontReply(); } private static RepairSessionResult createResult(CoordinatorSession coordinator) diff --git a/test/unit/org/apache/cassandra/repair/consistent/CoordinatorSessionTest.java b/test/unit/org/apache/cassandra/repair/consistent/CoordinatorSessionTest.java index c6980fe00862..1cee312a0e2f 100644 --- a/test/unit/org/apache/cassandra/repair/consistent/CoordinatorSessionTest.java +++ b/test/unit/org/apache/cassandra/repair/consistent/CoordinatorSessionTest.java @@ -35,6 +35,7 @@ import org.junit.Test; import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; import org.apache.cassandra.repair.AbstractRepairTest; import org.apache.cassandra.repair.RepairSessionResult; import org.apache.cassandra.repair.messages.FailSession; @@ -93,13 +94,13 @@ public InstrumentedCoordinatorSession(Builder 
builder) Map> sentMessages = new HashMap<>(); - protected void sendMessage(InetAddressAndPort destination, RepairMessage message) + protected void sendMessage(InetAddressAndPort destination, Message message) { if (!sentMessages.containsKey(destination)) { sentMessages.put(destination, new ArrayList<>()); } - sentMessages.get(destination).add(message); + sentMessages.get(destination).add(message.payload); } Runnable onSetRepairing = null; diff --git a/test/unit/org/apache/cassandra/repair/consistent/LocalSessionTest.java b/test/unit/org/apache/cassandra/repair/consistent/LocalSessionTest.java index a6b4fe26837d..15fd1fc096e9 100644 --- a/test/unit/org/apache/cassandra/repair/consistent/LocalSessionTest.java +++ b/test/unit/org/apache/cassandra/repair/consistent/LocalSessionTest.java @@ -44,6 +44,7 @@ import org.apache.cassandra.cql3.QueryProcessor; import org.apache.cassandra.cql3.statements.schema.CreateTableStatement; import org.apache.cassandra.locator.RangesAtEndpoint; +import org.apache.cassandra.net.Message; import org.apache.cassandra.repair.AbstractRepairTest; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.repair.KeyspaceRepairManager; @@ -123,13 +124,14 @@ private static void assertMessagesSent(InstrumentedLocalSessions sessions, InetA static class InstrumentedLocalSessions extends LocalSessions { Map> sentMessages = new HashMap<>(); - protected void sendMessage(InetAddressAndPort destination, RepairMessage message) + + protected void sendMessage(InetAddressAndPort destination, Message message) { if (!sentMessages.containsKey(destination)) { sentMessages.put(destination, new ArrayList<>()); } - sentMessages.get(destination).add(message); + sentMessages.get(destination).add(message.payload); } SettableFuture prepareSessionFuture = null; diff --git a/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializationsTest.java b/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializationsTest.java index 
d583d851fa9d..fa037a09af4d 100644 --- a/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializationsTest.java +++ b/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializationsTest.java @@ -118,7 +118,7 @@ private T serializeRoundTrip(T msg, IVersionedSerializ @Test public void validationCompleteMessage_NoMerkleTree() throws IOException { - ValidationComplete deserialized = validationCompleteMessage(null); + ValidationResponse deserialized = validationCompleteMessage(null); Assert.assertNull(deserialized.trees); } @@ -127,19 +127,19 @@ public void validationCompleteMessage_WithMerkleTree() throws IOException { MerkleTrees trees = new MerkleTrees(Murmur3Partitioner.instance); trees.addMerkleTree(256, new Range<>(new LongToken(1000), new LongToken(1001))); - ValidationComplete deserialized = validationCompleteMessage(trees); + ValidationResponse deserialized = validationCompleteMessage(trees); // a simple check to make sure we got some merkle trees back. Assert.assertEquals(trees.size(), deserialized.trees.size()); } - private ValidationComplete validationCompleteMessage(MerkleTrees trees) throws IOException + private ValidationResponse validationCompleteMessage(MerkleTrees trees) throws IOException { RepairJobDesc jobDesc = buildRepairJobDesc(); - ValidationComplete msg = trees == null ? - new ValidationComplete(jobDesc) : - new ValidationComplete(jobDesc, trees); - ValidationComplete deserialized = serializeRoundTrip(msg, ValidationComplete.serializer); + ValidationResponse msg = trees == null ? 
+ new ValidationResponse(jobDesc) : + new ValidationResponse(jobDesc, trees); + ValidationResponse deserialized = serializeRoundTrip(msg, ValidationResponse.serializer); return deserialized; } @@ -164,8 +164,8 @@ public void syncCompleteMessage() throws IOException Lists.newArrayList(new StreamSummary(TableId.fromUUID(UUIDGen.getTimeUUID()), 5, 100)), Lists.newArrayList(new StreamSummary(TableId.fromUUID(UUIDGen.getTimeUUID()), 500, 10)) )); - SyncComplete msg = new SyncComplete(buildRepairJobDesc(), new SyncNodePair(src, dst), true, summaries); - serializeRoundTrip(msg, SyncComplete.serializer); + SyncResponse msg = new SyncResponse(buildRepairJobDesc(), new SyncNodePair(src, dst), true, summaries); + serializeRoundTrip(msg, SyncResponse.serializer); } @Test diff --git a/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializerTest.java b/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializerTest.java index d876139a3fe3..fedf498aeb99 100644 --- a/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializerTest.java +++ b/test/unit/org/apache/cassandra/repair/messages/RepairMessageSerializerTest.java @@ -24,6 +24,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.locator.InetAddressAndPort; @@ -37,16 +38,16 @@ public class RepairMessageSerializerTest { private static int MS_VERSION = MessagingService.current_version; - static RepairMessage serdes(RepairMessage message) + private static T serdes(IVersionedSerializer serializer, T message) { - int expectedSize = (int) RepairMessage.serializer.serializedSize(message, MS_VERSION); + int expectedSize = (int) serializer.serializedSize(message, MS_VERSION); try (DataOutputBuffer out = new DataOutputBuffer(expectedSize)) { - RepairMessage.serializer.serialize(message, out, MS_VERSION); + 
serializer.serialize(message, out, MS_VERSION); Assert.assertEquals(expectedSize, out.buffer().limit()); try (DataInputBuffer in = new DataInputBuffer(out.buffer(), false)) { - return RepairMessage.serializer.deserialize(in, MS_VERSION); + return serializer.deserialize(in, MS_VERSION); } } catch (IOException e) @@ -62,54 +63,50 @@ public void prepareConsistentRequest() throws Exception InetAddressAndPort peer1 = InetAddressAndPort.getByName("10.0.0.2"); InetAddressAndPort peer2 = InetAddressAndPort.getByName("10.0.0.3"); InetAddressAndPort peer3 = InetAddressAndPort.getByName("10.0.0.4"); - RepairMessage expected = new PrepareConsistentRequest(UUIDGen.getTimeUUID(), - coordinator, - Sets.newHashSet(peer1, peer2, peer3)); - RepairMessage actual = serdes(expected); + PrepareConsistentRequest expected = + new PrepareConsistentRequest(UUIDGen.getTimeUUID(), coordinator, Sets.newHashSet(peer1, peer2, peer3)); + PrepareConsistentRequest actual = serdes(PrepareConsistentRequest.serializer, expected); Assert.assertEquals(expected, actual); } @Test public void prepareConsistentResponse() throws Exception { - RepairMessage expected = new PrepareConsistentResponse(UUIDGen.getTimeUUID(), - InetAddressAndPort.getByName("10.0.0.2"), - true); - RepairMessage actual = serdes(expected); + PrepareConsistentResponse expected = + new PrepareConsistentResponse(UUIDGen.getTimeUUID(), InetAddressAndPort.getByName("10.0.0.2"), true); + PrepareConsistentResponse actual = serdes(PrepareConsistentResponse.serializer, expected); Assert.assertEquals(expected, actual); } @Test public void failSession() throws Exception { - RepairMessage expected = new FailSession(UUIDGen.getTimeUUID()); - RepairMessage actual = serdes(expected); + FailSession expected = new FailSession(UUIDGen.getTimeUUID()); + FailSession actual = serdes(FailSession.serializer, expected); Assert.assertEquals(expected, actual);; } @Test public void finalizeCommit() throws Exception { - RepairMessage expected = new 
FinalizeCommit(UUIDGen.getTimeUUID()); - RepairMessage actual = serdes(expected); + FinalizeCommit expected = new FinalizeCommit(UUIDGen.getTimeUUID()); + FinalizeCommit actual = serdes(FinalizeCommit.serializer, expected); Assert.assertEquals(expected, actual);; } @Test public void finalizePromise() throws Exception { - RepairMessage expected = new FinalizePromise(UUIDGen.getTimeUUID(), - InetAddressAndPort.getByName("10.0.0.2"), - true); - RepairMessage actual = serdes(expected); + FinalizePromise expected = new FinalizePromise(UUIDGen.getTimeUUID(), InetAddressAndPort.getByName("10.0.0.2"), true); + FinalizePromise actual = serdes(FinalizePromise.serializer, expected); Assert.assertEquals(expected, actual); } @Test public void finalizePropose() throws Exception { - RepairMessage expected = new FinalizePropose(UUIDGen.getTimeUUID()); - RepairMessage actual = serdes(expected); + FinalizePropose expected = new FinalizePropose(UUIDGen.getTimeUUID()); + FinalizePropose actual = serdes(FinalizePropose.serializer, expected); Assert.assertEquals(expected, actual);; } } diff --git a/test/unit/org/apache/cassandra/schema/MigrationManagerTest.java b/test/unit/org/apache/cassandra/schema/MigrationManagerTest.java index 5c709036c00d..f7dedea702d6 100644 --- a/test/unit/org/apache/cassandra/schema/MigrationManagerTest.java +++ b/test/unit/org/apache/cassandra/schema/MigrationManagerTest.java @@ -179,7 +179,7 @@ public void addNewTable() throws ConfigurationException // flush to exercise more than just hitting the memtable ColumnFamilyStore cfs = Keyspace.open(ksName).getColumnFamilyStore(tableName); assertNotNull(cfs); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); // and make sure we get out what we put in UntypedResultSet rows = QueryProcessor.executeInternal(String.format("SELECT * FROM %s.%s", ksName, tableName)); @@ -202,7 +202,7 @@ public void dropCf() throws ConfigurationException "dropCf", "col" + i, "anyvalue"); ColumnFamilyStore store = 
Keyspace.open(cfm.keyspace).getColumnFamilyStore(cfm.name); assertNotNull(store); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); assertTrue(store.getDirectories().sstableLister(Directories.OnTxnErr.THROW).list().size() > 0); MigrationManager.announceTableDrop(ks.name, cfm.name, false); @@ -251,7 +251,7 @@ public void addNewKS() throws ConfigurationException "key0", "col0", "val0"); ColumnFamilyStore store = Keyspace.open(cfm.keyspace).getColumnFamilyStore(cfm.name); assertNotNull(store); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); UntypedResultSet rows = QueryProcessor.executeInternal("SELECT * FROM newkeyspace1.newstandard1"); assertRows(rows, row("key0", "col0", "val0")); @@ -273,7 +273,7 @@ public void dropKS() throws ConfigurationException "dropKs", "col" + i, "anyvalue"); ColumnFamilyStore cfs = Keyspace.open(cfm.keyspace).getColumnFamilyStore(cfm.name); assertNotNull(cfs); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); assertTrue(!cfs.getDirectories().sstableLister(Directories.OnTxnErr.THROW).list().isEmpty()); MigrationManager.announceKeyspaceDrop(ks.name); @@ -354,7 +354,7 @@ public void createEmptyKsAddNewCf() throws ConfigurationException ColumnFamilyStore cfs = Keyspace.open(newKs.name).getColumnFamilyStore(newCf.name); assertNotNull(cfs); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); UntypedResultSet rows = QueryProcessor.executeInternal(String.format("SELECT * FROM %s.%s", EMPTY_KEYSPACE, tableName)); assertRows(rows, row("key0", "col0", "val0")); @@ -509,7 +509,7 @@ public void testDropIndex() throws ConfigurationException TABLE1i), "key0", "col0", 1L, 1L); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); ColumnFamilyStore indexCfs = cfs.indexManager.getIndexByName(indexName) .getBackingTable() .orElseThrow(throwAssert("Cannot access index cfs")); diff --git a/test/unit/org/apache/cassandra/security/SSLFactoryTest.java 
b/test/unit/org/apache/cassandra/security/SSLFactoryTest.java index 5fdbe7b40f4d..d2e85b222a4e 100644 --- a/test/unit/org/apache/cassandra/security/SSLFactoryTest.java +++ b/test/unit/org/apache/cassandra/security/SSLFactoryTest.java @@ -21,7 +21,6 @@ import java.io.File; import java.io.IOException; import java.security.cert.CertificateException; -import java.util.Arrays; import javax.net.ssl.TrustManagerFactory; import org.apache.commons.io.FileUtils; @@ -65,11 +64,11 @@ public class SSLFactoryTest @Before public void setup() { - encryptionOptions = new ServerEncryptionOptions(); - encryptionOptions.truststore = "test/conf/cassandra_ssl_test.truststore"; - encryptionOptions.truststore_password = "cassandra"; - encryptionOptions.require_client_auth = false; - encryptionOptions.cipher_suites = new String[] {"TLS_RSA_WITH_AES_128_CBC_SHA"}; + encryptionOptions = new ServerEncryptionOptions() + .withTrustStore("test/conf/cassandra_ssl_test.truststore") + .withTrustStorePassword("cassandra") + .withRequireClientAuth(false) + .withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA"); SSLFactory.checkedExpiry = false; } @@ -108,28 +107,25 @@ public void getSslContext_JdkSsl() throws IOException SslContext sslContext = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, false); Assert.assertNotNull(sslContext); Assert.assertTrue(sslContext instanceof JdkSslContext); - Assert.assertEquals(Arrays.asList(encryptionOptions.cipher_suites), sslContext.cipherSuites()); + Assert.assertEquals(encryptionOptions.cipher_suites, sslContext.cipherSuites()); } - private EncryptionOptions addKeystoreOptions(EncryptionOptions options) + private ServerEncryptionOptions addKeystoreOptions(ServerEncryptionOptions options) { - options.keystore = "test/conf/cassandra_ssl_test.keystore"; - options.keystore_password = "cassandra"; - return options; + return options.withKeyStore("test/conf/cassandra_ssl_test.keystore") + .withKeyStorePassword("cassandra"); } @Test(expected = 
IOException.class) public void buildTrustManagerFactory_NoFile() throws IOException { - encryptionOptions.truststore = "/this/is/probably/not/a/file/on/your/test/machine"; - SSLFactory.buildTrustManagerFactory(encryptionOptions); + SSLFactory.buildTrustManagerFactory(encryptionOptions.withTrustStore("/this/is/probably/not/a/file/on/your/test/machine")); } @Test(expected = IOException.class) public void buildTrustManagerFactory_BadPassword() throws IOException { - encryptionOptions.truststore_password = "HomeOfBadPasswords"; - SSLFactory.buildTrustManagerFactory(encryptionOptions); + SSLFactory.buildTrustManagerFactory(encryptionOptions.withTrustStorePassword("HomeOfBadPasswords")); } @Test @@ -142,16 +138,16 @@ public void buildTrustManagerFactory_HappyPath() throws IOException @Test(expected = IOException.class) public void buildKeyManagerFactory_NoFile() throws IOException { - EncryptionOptions options = addKeystoreOptions(encryptionOptions); - options.keystore = "/this/is/probably/not/a/file/on/your/test/machine"; + EncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withKeyStore("/this/is/probably/not/a/file/on/your/test/machine"); SSLFactory.buildKeyManagerFactory(options); } @Test(expected = IOException.class) public void buildKeyManagerFactory_BadPassword() throws IOException { - EncryptionOptions options = addKeystoreOptions(encryptionOptions); - encryptionOptions.keystore_password = "HomeOfBadPasswords"; + EncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withKeyStorePassword("HomeOfBadPasswords"); SSLFactory.buildKeyManagerFactory(options); } @@ -169,8 +165,8 @@ public void testSslContextReload_HappyPath() throws IOException, InterruptedExce { try { - EncryptionOptions options = addKeystoreOptions(encryptionOptions); - options.enabled = true; + EncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withEnabled(true); SSLFactory.initHotReloading((ServerEncryptionOptions) options, options, true); @@ -201,9 
+197,9 @@ public void testSslContextReload_HappyPath() throws IOException, InterruptedExce @Test(expected = IOException.class) public void testSslFactorySslInit_BadPassword_ThrowsException() throws IOException { - EncryptionOptions options = addKeystoreOptions(encryptionOptions); - options.keystore_password = "bad password"; - options.enabled = true; + EncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withKeyStorePassword("bad password") + .withEnabled(true); SSLFactory.initHotReloading((ServerEncryptionOptions) options, options, true); } @@ -213,10 +209,8 @@ public void testSslFactoryHotReload_BadPassword_DoesNotClearExistingSslContext() { try { - addKeystoreOptions(encryptionOptions); - - ServerEncryptionOptions options = new ServerEncryptionOptions(encryptionOptions); - options.enabled = true; + ServerEncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withEnabled(true); SSLFactory.initHotReloading(options, options, true); SslContext oldCtx = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, OpenSsl @@ -226,8 +220,8 @@ public void testSslFactoryHotReload_BadPassword_DoesNotClearExistingSslContext() SSLFactory.checkCertFilesForHotReloading(options, options); keystoreFile.setLastModified(System.currentTimeMillis() + 5000); - ServerEncryptionOptions modOptions = new ServerEncryptionOptions(options); - modOptions.keystore_password = "bad password"; + ServerEncryptionOptions modOptions = new ServerEncryptionOptions(options) + .withKeyStorePassword("bad password"); SSLFactory.checkCertFilesForHotReloading(modOptions, modOptions); SslContext newCtx = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, OpenSsl .isAvailable()); @@ -241,29 +235,28 @@ public void testSslFactoryHotReload_BadPassword_DoesNotClearExistingSslContext() } @Test - public void testSslFactoryHotReload_CorruptOrNonExistentFile_DoesNotClearExistingSslContext() throws IOException, - InterruptedException + public 
void testSslFactoryHotReload_CorruptOrNonExistentFile_DoesNotClearExistingSslContext() throws IOException { try { - addKeystoreOptions(encryptionOptions); + ServerEncryptionOptions options = addKeystoreOptions(encryptionOptions); - File testKeystoreFile = new File(encryptionOptions.keystore + ".test"); - FileUtils.copyFile(new File(encryptionOptions.keystore),testKeystoreFile); - encryptionOptions.keystore = testKeystoreFile.getPath(); + File testKeystoreFile = new File(options.keystore + ".test"); + FileUtils.copyFile(new File(options.keystore),testKeystoreFile); + options = options + .withKeyStore(testKeystoreFile.getPath()) + .withEnabled(true); - EncryptionOptions options = new ServerEncryptionOptions(encryptionOptions); - options.enabled = true; - SSLFactory.initHotReloading((ServerEncryptionOptions) options, options, true); + SSLFactory.initHotReloading(options, options, true); SslContext oldCtx = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, OpenSsl .isAvailable()); - SSLFactory.checkCertFilesForHotReloading((ServerEncryptionOptions) options, options); + SSLFactory.checkCertFilesForHotReloading(options, options); testKeystoreFile.setLastModified(System.currentTimeMillis() + 15000); FileUtils.forceDelete(testKeystoreFile); - SSLFactory.checkCertFilesForHotReloading((ServerEncryptionOptions) options, options);; + SSLFactory.checkCertFilesForHotReloading(options, options);; SslContext newCtx = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, OpenSsl .isAvailable()); @@ -283,22 +276,22 @@ public void testSslFactoryHotReload_CorruptOrNonExistentFile_DoesNotClearExistin @Test public void getSslContext_ParamChanges() throws IOException { - EncryptionOptions options = addKeystoreOptions(encryptionOptions); - options.enabled = true; - options.cipher_suites = new String[]{ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" }; + EncryptionOptions options = addKeystoreOptions(encryptionOptions) + .withEnabled(true) + 
.withCipherSuites("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"); SslContext ctx1 = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.SERVER, OpenSsl.isAvailable()); Assert.assertTrue(ctx1.isServer()); - Assert.assertArrayEquals(ctx1.cipherSuites().toArray(), options.cipher_suites); + Assert.assertEquals(ctx1.cipherSuites(), options.cipher_suites); - options.cipher_suites = new String[]{ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" }; + options = options.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); SslContext ctx2 = SSLFactory.getOrCreateSslContext(options, true, SSLFactory.SocketType.CLIENT, OpenSsl.isAvailable()); Assert.assertTrue(ctx2.isClient()); - Assert.assertArrayEquals(ctx2.cipherSuites().toArray(), options.cipher_suites); + Assert.assertEquals(ctx2.cipherSuites(), options.cipher_suites); } } diff --git a/test/unit/org/apache/cassandra/service/ActiveRepairServiceTest.java b/test/unit/org/apache/cassandra/service/ActiveRepairServiceTest.java index 4f7cde035acb..935dd19ec15c 100644 --- a/test/unit/org/apache/cassandra/service/ActiveRepairServiceTest.java +++ b/test/unit/org/apache/cassandra/service/ActiveRepairServiceTest.java @@ -304,7 +304,7 @@ private void createSSTables(ColumnFamilyStore cfs, int count) .build() .applyUnsafe(); } - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } } diff --git a/test/unit/org/apache/cassandra/service/ClientWarningsTest.java b/test/unit/org/apache/cassandra/service/ClientWarningsTest.java index 3ae49edf96a2..ffa12d19766d 100644 --- a/test/unit/org/apache/cassandra/service/ClientWarningsTest.java +++ b/test/unit/org/apache/cassandra/service/ClientWarningsTest.java @@ -101,7 +101,7 @@ public void testTombstoneWarning() throws Exception client.execute(query); } ColumnFamilyStore store = Keyspace.open(KEYSPACE).getColumnFamilyStore(currentTable()); - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); for (int i = 0; i < iterations; i++) { @@ -111,7 +111,7 @@ public void 
testTombstoneWarning() throws Exception i), QueryOptions.DEFAULT); client.execute(query); } - store.forceBlockingFlush(); + store.forceBlockingFlushToSSTable(); { QueryMessage query = new QueryMessage(String.format("SELECT * FROM %s.%s WHERE pk = 1", diff --git a/test/unit/org/apache/cassandra/service/MoveTest.java b/test/unit/org/apache/cassandra/service/MoveTest.java index a7cfc1bddf32..a4da7b823906 100644 --- a/test/unit/org/apache/cassandra/service/MoveTest.java +++ b/test/unit/org/apache/cassandra/service/MoveTest.java @@ -496,7 +496,16 @@ private void finishMove(InetAddressAndPort host, int token, TokenMetadata tmd) { tmd.removeFromMoving(host); assertTrue(!tmd.isMoving(host)); - tmd.updateNormalToken(new BigIntegerToken(String.valueOf(token)), host); + Token newToken = new BigIntegerToken(String.valueOf(token)); + tmd.updateNormalToken(newToken, host); + // As well as upating TMD, update the host's tokens in gossip. Since CASSANDRA-15120, status changing to MOVING + // ensures that TMD is up to date with token assignments according to gossip. So we need to make sure gossip has + // the correct new token, as the moving node itself would do upon successful completion of the move operation. + // Without this, the next movement for that host will set the token in TMD's back to the old value from gossip + // and incorrect range movements will follow + Gossiper.instance.injectApplicationState(host, + ApplicationState.TOKENS, + new VersionedValue.VersionedValueFactory(partitioner).tokens(Collections.singleton(newToken))); } private Map.Entry, EndpointsForRange> generatePendingMapEntry(int start, int end, String... 
endpoints) throws UnknownHostException diff --git a/test/unit/org/apache/cassandra/service/NativeTransportServiceTest.java b/test/unit/org/apache/cassandra/service/NativeTransportServiceTest.java index c918fd6f1f38..86b73ab580fb 100644 --- a/test/unit/org/apache/cassandra/service/NativeTransportServiceTest.java +++ b/test/unit/org/apache/cassandra/service/NativeTransportServiceTest.java @@ -47,7 +47,7 @@ public static void setupDD() @After public void resetConfig() { - DatabaseDescriptor.getNativeProtocolEncryptionOptions().enabled = false; + DatabaseDescriptor.updateNativeProtocolEncryptionOptions(options -> options.withEnabled(false)); DatabaseDescriptor.setNativeTransportPortSSL(null); } @@ -85,8 +85,7 @@ public void testDestroy() { withService((NativeTransportService service) -> { BooleanSupplier allTerminated = () -> - service.getWorkerGroup().isShutdown() && service.getWorkerGroup().isTerminated() && - service.getEventExecutor().isShutdown() && service.getEventExecutor().isTerminated(); + service.getWorkerGroup().isShutdown() && service.getWorkerGroup().isTerminated(); assertFalse(allTerminated.getAsBoolean()); service.destroy(); assertTrue(allTerminated.getAsBoolean()); @@ -128,8 +127,8 @@ public void testPlainDefaultPort() public void testSSLOnly() { // default ssl settings: client encryption enabled and default native transport port used for ssl only - DatabaseDescriptor.getNativeProtocolEncryptionOptions().enabled = true; - DatabaseDescriptor.getNativeProtocolEncryptionOptions().optional = false; + DatabaseDescriptor.updateNativeProtocolEncryptionOptions(options -> options.withEnabled(true) + .withOptional(false)); withService((NativeTransportService service) -> { @@ -145,8 +144,8 @@ public void testSSLOnly() public void testSSLOptional() { // default ssl settings: client encryption enabled and default native transport port used for optional ssl - DatabaseDescriptor.getNativeProtocolEncryptionOptions().enabled = true; - 
DatabaseDescriptor.getNativeProtocolEncryptionOptions().optional = true; + DatabaseDescriptor.updateNativeProtocolEncryptionOptions(options -> options.withEnabled(true) + .withOptional(true)); withService((NativeTransportService service) -> { @@ -162,7 +161,7 @@ public void testSSLOptional() public void testSSLWithNonSSL() { // ssl+non-ssl settings: client encryption enabled and additional ssl port specified - DatabaseDescriptor.getNativeProtocolEncryptionOptions().enabled = true; + DatabaseDescriptor.updateNativeProtocolEncryptionOptions(options -> options.withEnabled(true)); DatabaseDescriptor.setNativeTransportPortSSL(8432); withService((NativeTransportService service) -> diff --git a/test/unit/org/apache/cassandra/service/QueryPagerTest.java b/test/unit/org/apache/cassandra/service/QueryPagerTest.java index 407efc6261d0..0b8248a3197c 100644 --- a/test/unit/org/apache/cassandra/service/QueryPagerTest.java +++ b/test/unit/org/apache/cassandra/service/QueryPagerTest.java @@ -246,9 +246,10 @@ public void namesQueryTest() throws Exception public void sliceQueryTest() throws Exception { sliceQueryTest(false, ProtocolVersion.V3); - sliceQueryTest(true, ProtocolVersion.V4); - sliceQueryTest(false, ProtocolVersion.V3); - sliceQueryTest(true, ProtocolVersion.V4); + sliceQueryTest(true, ProtocolVersion.V3); + + sliceQueryTest(false, ProtocolVersion.V4); + sliceQueryTest(true, ProtocolVersion.V4); } public void sliceQueryTest(boolean testPagingState, ProtocolVersion protocolVersion) throws Exception @@ -279,9 +280,10 @@ public void sliceQueryTest(boolean testPagingState, ProtocolVersion protocolVers public void reversedSliceQueryTest() throws Exception { reversedSliceQueryTest(false, ProtocolVersion.V3); - reversedSliceQueryTest(true, ProtocolVersion.V4); - reversedSliceQueryTest(false, ProtocolVersion.V3); - reversedSliceQueryTest(true, ProtocolVersion.V4); + reversedSliceQueryTest(true, ProtocolVersion.V3); + + reversedSliceQueryTest(false, ProtocolVersion.V4); + 
reversedSliceQueryTest(true, ProtocolVersion.V4); } public void reversedSliceQueryTest(boolean testPagingState, ProtocolVersion protocolVersion) throws Exception @@ -312,9 +314,10 @@ public void reversedSliceQueryTest(boolean testPagingState, ProtocolVersion prot public void multiQueryTest() throws Exception { multiQueryTest(false, ProtocolVersion.V3); - multiQueryTest(true, ProtocolVersion.V4); - multiQueryTest(false, ProtocolVersion.V3); - multiQueryTest(true, ProtocolVersion.V4); + multiQueryTest(true, ProtocolVersion.V3); + + multiQueryTest(false, ProtocolVersion.V4); + multiQueryTest(true, ProtocolVersion.V4); } public void multiQueryTest(boolean testPagingState, ProtocolVersion protocolVersion) throws Exception @@ -350,9 +353,10 @@ public void multiQueryTest(boolean testPagingState, ProtocolVersion protocolVers public void rangeNamesQueryTest() throws Exception { rangeNamesQueryTest(false, ProtocolVersion.V3); - rangeNamesQueryTest(true, ProtocolVersion.V4); - rangeNamesQueryTest(false, ProtocolVersion.V3); - rangeNamesQueryTest(true, ProtocolVersion.V4); + rangeNamesQueryTest(true, ProtocolVersion.V3); + + rangeNamesQueryTest(false, ProtocolVersion.V4); + rangeNamesQueryTest(true, ProtocolVersion.V4); } public void rangeNamesQueryTest(boolean testPagingState, ProtocolVersion protocolVersion) throws Exception @@ -379,9 +383,10 @@ public void rangeNamesQueryTest(boolean testPagingState, ProtocolVersion protoco public void rangeSliceQueryTest() throws Exception { rangeSliceQueryTest(false, ProtocolVersion.V3); - rangeSliceQueryTest(true, ProtocolVersion.V4); - rangeSliceQueryTest(false, ProtocolVersion.V3); - rangeSliceQueryTest(true, ProtocolVersion.V4); + rangeSliceQueryTest(true, ProtocolVersion.V3); + + rangeSliceQueryTest(false, ProtocolVersion.V4); + rangeSliceQueryTest(true, ProtocolVersion.V4); } public void rangeSliceQueryTest(boolean testPagingState, ProtocolVersion protocolVersion) throws Exception diff --git 
a/test/unit/org/apache/cassandra/service/RemoveTest.java b/test/unit/org/apache/cassandra/service/RemoveTest.java index 0d39322ef678..e6fbe7beedeb 100644 --- a/test/unit/org/apache/cassandra/service/RemoveTest.java +++ b/test/unit/org/apache/cassandra/service/RemoveTest.java @@ -41,10 +41,12 @@ import org.apache.cassandra.gms.VersionedValue.VersionedValueFactory; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.TokenMetadata; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.utils.FBUtilities; +import static org.apache.cassandra.net.NoPayload.noPayload; +import static org.apache.cassandra.net.Verb.REPLICATION_DONE_REQ; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -96,8 +98,9 @@ public void setup() throws IOException, ConfigurationException @After public void tearDown() { - MessagingService.instance().clearMessageSinks(); - MessagingService.instance().clearCallbacksUnsafe(); + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + MessagingService.instance().callbacks.unsafeClear(); } @Test(expected = UnsupportedOperationException.class) @@ -161,8 +164,10 @@ public void testRemoveHostId() throws InterruptedException for (InetAddressAndPort host : hosts) { - MessageOut msg = new MessageOut(host, MessagingService.Verb.REPLICATION_FINISHED, null, null, Collections.emptyList(), null); - MessagingService.instance().sendRR(msg, FBUtilities.getBroadcastAddressAndPort()); + Message msg = Message.builder(REPLICATION_DONE_REQ, noPayload) + .from(host) + .build(); + MessagingService.instance().send(msg, FBUtilities.getBroadcastAddressAndPort()); } remover.join(); diff --git a/test/unit/org/apache/cassandra/service/SerializationsTest.java b/test/unit/org/apache/cassandra/service/SerializationsTest.java index 
12236831abc0..0a5a023a48e0 100644 --- a/test/unit/org/apache/cassandra/service/SerializationsTest.java +++ b/test/unit/org/apache/cassandra/service/SerializationsTest.java @@ -38,10 +38,10 @@ import org.apache.cassandra.dht.RandomPartitioner; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.io.IVersionedSerializer; import org.apache.cassandra.io.util.DataInputPlus.DataInputStreamPlus; import org.apache.cassandra.io.util.DataOutputStreamPlus; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; import org.apache.cassandra.repair.SyncNodePair; import org.apache.cassandra.repair.RepairJobDesc; import org.apache.cassandra.repair.Validator; @@ -79,25 +79,22 @@ public static void tearDown() partitionerSwitcher.close(); } - private void testRepairMessageWrite(String fileName, RepairMessage... messages) throws IOException + private void testRepairMessageWrite(String fileName, IVersionedSerializer serializer, T... 
messages) throws IOException { try (DataOutputStreamPlus out = getOutput(fileName)) { - for (RepairMessage message : messages) + for (T message : messages) { - testSerializedSize(message, RepairMessage.serializer); - RepairMessage.serializer.serialize(message, out, getVersion()); + testSerializedSize(message, serializer); + serializer.serialize(message, out, getVersion()); } - // also serialize MessageOut - for (RepairMessage message : messages) - message.createMessage().serialize(out, getVersion()); } } private void testValidationRequestWrite() throws IOException { ValidationRequest message = new ValidationRequest(DESC, 1234); - testRepairMessageWrite("service.ValidationRequest.bin", message); + testRepairMessageWrite("service.ValidationRequest.bin", ValidationRequest.serializer, message); } @Test @@ -108,12 +105,9 @@ public void testValidationRequestRead() throws IOException try (DataInputStreamPlus in = getInput("service.ValidationRequest.bin")) { - RepairMessage message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.VALIDATION_REQUEST; + ValidationRequest message = ValidationRequest.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert ((ValidationRequest) message).nowInSec == 1234; - - assert MessageIn.read(in, getVersion(), -1) != null; + assert message.nowInSec == 1234; } } @@ -126,7 +120,7 @@ private void testValidationCompleteWrite() throws IOException // empty validation mt.addMerkleTree((int) Math.pow(2, 15), FULL_RANGE); Validator v0 = new Validator(DESC, FBUtilities.getBroadcastAddressAndPort(), -1, PreviewKind.NONE); - ValidationComplete c0 = new ValidationComplete(DESC, mt); + ValidationResponse c0 = new ValidationResponse(DESC, mt); // validation with a tree mt = new MerkleTrees(p); @@ -134,12 +128,12 @@ private void testValidationCompleteWrite() throws IOException for (int i = 0; i < 10; i++) mt.split(p.getRandomToken()); Validator v1 = new Validator(DESC, 
FBUtilities.getBroadcastAddressAndPort(), -1, PreviewKind.NONE); - ValidationComplete c1 = new ValidationComplete(DESC, mt); + ValidationResponse c1 = new ValidationResponse(DESC, mt); // validation failed - ValidationComplete c3 = new ValidationComplete(DESC); + ValidationResponse c3 = new ValidationResponse(DESC); - testRepairMessageWrite("service.ValidationComplete.bin", c0, c1, c3); + testRepairMessageWrite("service.ValidationComplete.bin", ValidationResponse.serializer, c0, c1, c3); } @Test @@ -151,32 +145,25 @@ public void testValidationCompleteRead() throws IOException try (DataInputStreamPlus in = getInput("service.ValidationComplete.bin")) { // empty validation - RepairMessage message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.VALIDATION_COMPLETE; + ValidationResponse message = ValidationResponse.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert ((ValidationComplete) message).success(); - assert ((ValidationComplete) message).trees != null; + assert message.success(); + assert message.trees != null; // validation with a tree - message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.VALIDATION_COMPLETE; + message = ValidationResponse.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert ((ValidationComplete) message).success(); - assert ((ValidationComplete) message).trees != null; + assert message.success(); + assert message.trees != null; // failed validation - message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.VALIDATION_COMPLETE; + message = ValidationResponse.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert !((ValidationComplete) message).success(); - assert ((ValidationComplete) message).trees == null; - - // MessageOuts - for (int i = 0; i < 3; i++) - assert 
MessageIn.read(in, getVersion(), -1) != null; + assert !message.success(); + assert message.trees == null; } } @@ -187,7 +174,7 @@ private void testSyncRequestWrite() throws IOException InetAddressAndPort dest = InetAddressAndPort.getByNameOverrideDefaults("127.0.0.3", PORT); SyncRequest message = new SyncRequest(DESC, local, src, dest, Collections.singleton(FULL_RANGE), PreviewKind.NONE); - testRepairMessageWrite("service.SyncRequest.bin", message); + testRepairMessageWrite("service.SyncRequest.bin", SyncRequest.serializer, message); } @Test @@ -202,15 +189,12 @@ public void testSyncRequestRead() throws IOException try (DataInputStreamPlus in = getInput("service.SyncRequest.bin")) { - RepairMessage message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.SYNC_REQUEST; + SyncRequest message = SyncRequest.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert local.equals(((SyncRequest) message).initiator); - assert src.equals(((SyncRequest) message).src); - assert dest.equals(((SyncRequest) message).dst); - assert ((SyncRequest) message).ranges.size() == 1 && ((SyncRequest) message).ranges.contains(FULL_RANGE); - - assert MessageIn.read(in, getVersion(), -1) != null; + assert local.equals(message.initiator); + assert src.equals(message.src); + assert dest.equals(message.dst); + assert message.ranges.size() == 1 && message.ranges.contains(FULL_RANGE); } } @@ -224,11 +208,11 @@ private void testSyncCompleteWrite() throws IOException Lists.newArrayList(new StreamSummary(TableId.fromUUID(UUIDGen.getTimeUUID()), 5, 100)), Lists.newArrayList(new StreamSummary(TableId.fromUUID(UUIDGen.getTimeUUID()), 500, 10)) )); - SyncComplete success = new SyncComplete(DESC, src, dest, true, summaries); + SyncResponse success = new SyncResponse(DESC, src, dest, true, summaries); // sync fail - SyncComplete fail = new SyncComplete(DESC, src, dest, false, Collections.emptyList()); + SyncResponse 
fail = new SyncResponse(DESC, src, dest, false, Collections.emptyList()); - testRepairMessageWrite("service.SyncComplete.bin", success, fail); + testRepairMessageWrite("service.SyncComplete.bin", SyncResponse.serializer, success, fail); } @Test @@ -244,26 +228,20 @@ public void testSyncCompleteRead() throws IOException try (DataInputStreamPlus in = getInput("service.SyncComplete.bin")) { // success - RepairMessage message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.SYNC_COMPLETE; + SyncResponse message = SyncResponse.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); System.out.println(nodes); - System.out.println(((SyncComplete) message).nodes); - assert nodes.equals(((SyncComplete) message).nodes); - assert ((SyncComplete) message).success; + System.out.println(message.nodes); + assert nodes.equals(message.nodes); + assert message.success; // fail - message = RepairMessage.serializer.deserialize(in, getVersion()); - assert message.messageType == RepairMessage.Type.SYNC_COMPLETE; + message = SyncResponse.serializer.deserialize(in, getVersion()); assert DESC.equals(message.desc); - assert nodes.equals(((SyncComplete) message).nodes); - assert !((SyncComplete) message).success; - - // MessageOuts - for (int i = 0; i < 2; i++) - assert MessageIn.read(in, getVersion(), -1) != null; + assert nodes.equals(message.nodes); + assert !message.success; } } } diff --git a/test/unit/org/apache/cassandra/service/StorageServiceServerTest.java b/test/unit/org/apache/cassandra/service/StorageServiceServerTest.java index 2db221bff8c0..565d91a7f9e3 100644 --- a/test/unit/org/apache/cassandra/service/StorageServiceServerTest.java +++ b/test/unit/org/apache/cassandra/service/StorageServiceServerTest.java @@ -28,6 +28,7 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import 
org.junit.runner.RunWith; @@ -621,12 +622,20 @@ public void testGetNativeAddress() throws Exception assertEquals("127.0.0.3:666", StorageService.instance.getNativeaddress(internalAddress, true)); } - @Test(expected = IllegalStateException.class) + @Test public void testAuditLogEnableLoggerNotFound() throws Exception { StorageService.instance.enableAuditLog(null, null, null, null, null, null, null); assertTrue(AuditLogManager.getInstance().isAuditingEnabled()); - StorageService.instance.enableAuditLog("foobar", null, null, null, null, null, null); + try + { + StorageService.instance.enableAuditLog("foobar", null, null, null, null, null, null); + Assert.fail(); + } + catch (IllegalStateException ex) + { + StorageService.instance.disableAuditLog(); + } } @Test @@ -646,5 +655,7 @@ public void testAuditLogEnableLoggerTransitions() throws Exception StorageService.instance.enableAuditLog(null, null, null, null, null, null, null); assertTrue(AuditLogManager.getInstance().isAuditingEnabled()); + + StorageService.instance.disableAuditLog(); } } diff --git a/test/unit/org/apache/cassandra/service/WriteResponseHandlerTest.java b/test/unit/org/apache/cassandra/service/WriteResponseHandlerTest.java index 2c186ba98d90..f06b706c68f5 100644 --- a/test/unit/org/apache/cassandra/service/WriteResponseHandlerTest.java +++ b/test/unit/org/apache/cassandra/service/WriteResponseHandlerTest.java @@ -44,10 +44,12 @@ import org.apache.cassandra.locator.ReplicaCollection; import org.apache.cassandra.locator.ReplicaUtils; import org.apache.cassandra.locator.TokenMetadata; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.utils.ByteBufferUtil; +import static org.apache.cassandra.net.NoPayload.noPayload; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -146,11 +148,11 @@ public void 
idealCLLatencyTracked() throws Throwable AbstractWriteResponseHandler awr = createWriteResponseHandler(ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.EACH_QUORUM, System.nanoTime() - TimeUnit.DAYS.toNanos(1)); //dc1 - awr.response(createDummyMessage(0)); - awr.response(createDummyMessage(1)); + awr.onResponse(createDummyMessage(0)); + awr.onResponse(createDummyMessage(1)); //dc2 - awr.response(createDummyMessage(4)); - awr.response(createDummyMessage(5)); + awr.onResponse(createDummyMessage(4)); + awr.onResponse(createDummyMessage(5)); //Don't need the others awr.expired(); @@ -172,13 +174,13 @@ public void idealCLWriteResponeHandlerWorks() throws Throwable AbstractWriteResponseHandler awr = createWriteResponseHandler(ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.ALL); //dc1 - awr.response(createDummyMessage(0)); - awr.response(createDummyMessage(1)); - awr.response(createDummyMessage(2)); + awr.onResponse(createDummyMessage(0)); + awr.onResponse(createDummyMessage(1)); + awr.onResponse(createDummyMessage(2)); //dc2 - awr.response(createDummyMessage(3)); - awr.response(createDummyMessage(4)); - awr.response(createDummyMessage(5)); + awr.onResponse(createDummyMessage(3)); + awr.onResponse(createDummyMessage(4)); + awr.onResponse(createDummyMessage(5)); assertEquals(0, ks.metric.writeFailedIdealCL.getCount()); assertEquals(startingCount + 1, ks.metric.idealCLWriteLatency.latency.getCount()); @@ -195,13 +197,13 @@ public void idealCLDatacenterWriteResponeHandlerWorks() throws Throwable AbstractWriteResponseHandler awr = createWriteResponseHandler(ConsistencyLevel.ONE, ConsistencyLevel.LOCAL_QUORUM); //dc1 - awr.response(createDummyMessage(0)); - awr.response(createDummyMessage(1)); - awr.response(createDummyMessage(2)); + awr.onResponse(createDummyMessage(0)); + awr.onResponse(createDummyMessage(1)); + awr.onResponse(createDummyMessage(2)); //dc2 - awr.response(createDummyMessage(3)); - awr.response(createDummyMessage(4)); - awr.response(createDummyMessage(5)); + 
awr.onResponse(createDummyMessage(3)); + awr.onResponse(createDummyMessage(4)); + awr.onResponse(createDummyMessage(5)); assertEquals(0, ks.metric.writeFailedIdealCL.getCount()); assertEquals(startingCount + 1, ks.metric.idealCLWriteLatency.latency.getCount()); @@ -218,9 +220,9 @@ public void failedIdealCLIncrementsStat() throws Throwable AbstractWriteResponseHandler awr = createWriteResponseHandler(ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.EACH_QUORUM); //Succeed in local DC - awr.response(createDummyMessage(0)); - awr.response(createDummyMessage(1)); - awr.response(createDummyMessage(2)); + awr.onResponse(createDummyMessage(0)); + awr.onResponse(createDummyMessage(1)); + awr.onResponse(createDummyMessage(2)); //Fail in remote DC awr.expired(); @@ -241,8 +243,10 @@ private static AbstractWriteResponseHandler createWriteResponseHandler(Consisten null, WriteType.SIMPLE, queryStartTime, ideal); } - private static MessageIn createDummyMessage(int target) + private static Message createDummyMessage(int target) { - return MessageIn.create(targets.get(target).endpoint(), null, null, null, 0, 0L); + return Message.builder(Verb.ECHO_REQ, noPayload) + .from(targets.get(target).endpoint()) + .build(); } } diff --git a/test/unit/org/apache/cassandra/service/pager/PagingStateTest.java b/test/unit/org/apache/cassandra/service/pager/PagingStateTest.java index 778088d274b8..58e863522549 100644 --- a/test/unit/org/apache/cassandra/service/pager/PagingStateTest.java +++ b/test/unit/org/apache/cassandra/service/pager/PagingStateTest.java @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -19,6 +18,7 @@ */ package org.apache.cassandra.service.pager; +import java.io.IOException; import java.nio.ByteBuffer; import org.junit.BeforeClass; @@ -65,7 +65,7 @@ public void testSerializationBackwardCompatibility() } @Test - public void testSerializeDeserializeV3() + public void testSerializeV3DeserializeV3() { PagingState state = Util.makeSomePagingState(ProtocolVersion.V3); ByteBuffer serialized = state.serialize(ProtocolVersion.V3); @@ -74,11 +74,47 @@ public void testSerializeDeserializeV3() } @Test - public void testSerializeDeserializeV4() + public void testSerializeV4DeserializeV4() { PagingState state = Util.makeSomePagingState(ProtocolVersion.V4); ByteBuffer serialized = state.serialize(ProtocolVersion.V4); assertEquals(serialized.remaining(), state.serializedSize(ProtocolVersion.V4)); assertEquals(state, PagingState.deserialize(serialized, ProtocolVersion.V4)); } + + @Test + public void testSerializeV3DeserializeV4() + { + PagingState state = Util.makeSomePagingState(ProtocolVersion.V3); + ByteBuffer serialized = state.serialize(ProtocolVersion.V3); + assertEquals(serialized.remaining(), state.serializedSize(ProtocolVersion.V3)); + assertEquals(state, PagingState.deserialize(serialized, ProtocolVersion.V4)); + } + + @Test + public void testSerializeV4DeserializeV3() + { + PagingState state = Util.makeSomePagingState(ProtocolVersion.V4); + ByteBuffer serialized = state.serialize(ProtocolVersion.V4); + assertEquals(serialized.remaining(), state.serializedSize(ProtocolVersion.V4)); + assertEquals(state, PagingState.deserialize(serialized, ProtocolVersion.V3)); + } + + @Test + public void testSerializeV3WithoutRemainingInPartitionDeserializeV3() throws IOException + { + PagingState state = Util.makeSomePagingState(ProtocolVersion.V3, Integer.MAX_VALUE); + ByteBuffer serialized = state.legacySerialize(false); + assertEquals(serialized.remaining(), state.legacySerializedSize(false)); + assertEquals(state, 
PagingState.deserialize(serialized, ProtocolVersion.V3)); + } + + @Test + public void testSerializeV3WithoutRemainingInPartitionDeserializeV4() throws IOException + { + PagingState state = Util.makeSomePagingState(ProtocolVersion.V3, Integer.MAX_VALUE); + ByteBuffer serialized = state.legacySerialize(false); + assertEquals(serialized.remaining(), state.legacySerializedSize(false)); + assertEquals(state, PagingState.deserialize(serialized, ProtocolVersion.V4)); + } } diff --git a/test/unit/org/apache/cassandra/service/reads/AbstractReadResponseTest.java b/test/unit/org/apache/cassandra/service/reads/AbstractReadResponseTest.java index 582aff8bf67a..545731b87d3d 100644 --- a/test/unit/org/apache/cassandra/service/reads/AbstractReadResponseTest.java +++ b/test/unit/org/apache/cassandra/service/reads/AbstractReadResponseTest.java @@ -65,7 +65,7 @@ import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.dht.Murmur3Partitioner; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.KeyspaceParams; @@ -73,6 +73,8 @@ import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.FBUtilities; +import static org.apache.cassandra.net.Verb.READ_REQ; + /** * Base class for testing various components which deal with read responses */ @@ -216,7 +218,7 @@ static DecoratedKey dk(int k) } - static MessageIn response(ReadCommand command, + static Message response(ReadCommand command, InetAddressAndPort from, UnfilteredPartitionIterator data, boolean isDigestResponse, @@ -227,10 +229,12 @@ static MessageIn response(ReadCommand command, ReadResponse response = isDigestResponse ? 
ReadResponse.createDigestResponse(data, command) : ReadResponse.createRemoteDataResponse(data, repairedDataDigest, hasPendingRepair, command, fromVersion); - return MessageIn.create(from, response, Collections.emptyMap(), MessagingService.Verb.READ, fromVersion); + return Message.builder(READ_REQ, response) + .from(from) + .build(); } - static MessageIn response(InetAddressAndPort from, + static Message response(InetAddressAndPort from, UnfilteredPartitionIterator partitionIterator, ByteBuffer repairedDigest, boolean hasPendingRepair, @@ -239,12 +243,12 @@ static MessageIn response(InetAddressAndPort from, return response(cmd, from, partitionIterator, false, MessagingService.current_version, repairedDigest, hasPendingRepair); } - static MessageIn response(ReadCommand command, InetAddressAndPort from, UnfilteredPartitionIterator data, boolean isDigestResponse) + static Message response(ReadCommand command, InetAddressAndPort from, UnfilteredPartitionIterator data, boolean isDigestResponse) { return response(command, from, data, false, MessagingService.current_version, ByteBufferUtil.EMPTY_BYTE_BUFFER, isDigestResponse); } - static MessageIn response(ReadCommand command, InetAddressAndPort from, UnfilteredPartitionIterator data) + static Message response(ReadCommand command, InetAddressAndPort from, UnfilteredPartitionIterator data) { return response(command, from, data, false, MessagingService.current_version, ByteBufferUtil.EMPTY_BYTE_BUFFER, false); } diff --git a/test/unit/org/apache/cassandra/service/reads/ReadExecutorTest.java b/test/unit/org/apache/cassandra/service/reads/ReadExecutorTest.java index 34be5ee5ed68..e0a59276ae5b 100644 --- a/test/unit/org/apache/cassandra/service/reads/ReadExecutorTest.java +++ b/test/unit/org/apache/cassandra/service/reads/ReadExecutorTest.java @@ -20,10 +20,8 @@ import java.util.concurrent.TimeUnit; -import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.dht.Murmur3Partitioner; import 
org.apache.cassandra.dht.Token; -import org.apache.cassandra.locator.EndpointsForRange; import org.apache.cassandra.locator.ReplicaPlan; import org.junit.Before; import org.junit.BeforeClass; @@ -40,12 +38,12 @@ import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.locator.EndpointsForToken; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.locator.ReplicaLayout; -import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessageOut; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.NoPayload; +import org.apache.cassandra.net.Verb; import org.apache.cassandra.schema.KeyspaceParams; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.cassandra.locator.ReplicaUtils.full; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -205,24 +203,16 @@ public static class MockSinglePartitionReadCommand extends SinglePartitionReadCo } @Override - public long getTimeout() + public long getTimeout(TimeUnit unit) { - return timeout; + return unit.convert(timeout, MILLISECONDS); } @Override - public MessageOut createMessage() + public Message createMessage(boolean trackRepairedData) { - return new MessageOut(MessagingService.Verb.BATCH_REMOVE) - { - @Override - public int serializedSize(int version) - { - return 0; - } - }; + return Message.out(Verb.ECHO_REQ, NoPayload.noPayload); } - } private ReplicaPlan.ForTokenRead plan(EndpointsForToken targets, ConsistencyLevel consistencyLevel) diff --git a/test/unit/org/apache/cassandra/service/reads/repair/AbstractReadRepairTest.java b/test/unit/org/apache/cassandra/service/reads/repair/AbstractReadRepairTest.java index 68afedcf7696..3d3973273cab 100644 --- a/test/unit/org/apache/cassandra/service/reads/repair/AbstractReadRepairTest.java +++ 
b/test/unit/org/apache/cassandra/service/reads/repair/AbstractReadRepairTest.java @@ -1,6 +1,5 @@ package org.apache.cassandra.service.reads.repair; -import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -42,10 +41,8 @@ import org.apache.cassandra.locator.EndpointsForRange; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.locator.ReplicaLayout; import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessageIn; -import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Message; import org.apache.cassandra.schema.KeyspaceMetadata; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.schema.MigrationManager; @@ -55,6 +52,7 @@ import static org.apache.cassandra.locator.Replica.fullReplica; import static org.apache.cassandra.locator.ReplicaUtils.FULL_RANGE; +import static org.apache.cassandra.net.Verb.INTERNAL_RSP; @Ignore public abstract class AbstractReadRepairTest @@ -163,14 +161,12 @@ static Mutation mutation(Cell... cells) } @SuppressWarnings("resource") - static MessageIn msg(InetAddressAndPort from, Cell... cells) + static Message msg(InetAddressAndPort from, Cell... 
cells) { UnfilteredPartitionIterator iter = new SingletonUnfilteredPartitionIterator(update(cells).unfilteredIterator()); - return MessageIn.create(from, - ReadResponse.createDataResponse(iter, command), - Collections.emptyMap(), - MessagingService.Verb.INTERNAL_RESPONSE, - MessagingService.current_version); + return Message.builder(INTERNAL_RSP, ReadResponse.createDataResponse(iter, command)) + .from(from) + .build(); } static class ResultConsumer implements Consumer @@ -306,8 +302,8 @@ public void noSpeculationRequired() repair.startRepair(null, consumer); Assert.assertEquals(epSet(target1, target2), repair.getReadRecipients()); - repair.getReadCallback().response(msg(target1, cell1)); - repair.getReadCallback().response(msg(target2, cell1)); + repair.getReadCallback().onResponse(msg(target1, cell1)); + repair.getReadCallback().onResponse(msg(target2, cell1)); repair.maybeSendAdditionalReads(); Assert.assertEquals(epSet(target1, target2), repair.getReadRecipients()); diff --git a/test/unit/org/apache/cassandra/service/reads/repair/BlockingReadRepairTest.java b/test/unit/org/apache/cassandra/service/reads/repair/BlockingReadRepairTest.java index 6ea593deb520..7538832add9b 100644 --- a/test/unit/org/apache/cassandra/service/reads/repair/BlockingReadRepairTest.java +++ b/test/unit/org/apache/cassandra/service/reads/repair/BlockingReadRepairTest.java @@ -39,7 +39,7 @@ import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.service.reads.ReadCallback; public class BlockingReadRepairTest extends AbstractReadRepairTest @@ -55,7 +55,7 @@ public InstrumentedReadRepairHandler(Map repairs, int maxBloc Map mutationsSent = new HashMap<>(); - protected void sendRR(MessageOut message, InetAddressAndPort endpoint) + protected void sendRR(Message message, InetAddressAndPort 
endpoint) { mutationsSent.put(endpoint, message.payload); } diff --git a/test/unit/org/apache/cassandra/service/reads/repair/DiagEventsBlockingReadRepairTest.java b/test/unit/org/apache/cassandra/service/reads/repair/DiagEventsBlockingReadRepairTest.java index befa07af1f1a..3bcd757f2d1b 100644 --- a/test/unit/org/apache/cassandra/service/reads/repair/DiagEventsBlockingReadRepairTest.java +++ b/test/unit/org/apache/cassandra/service/reads/repair/DiagEventsBlockingReadRepairTest.java @@ -44,7 +44,7 @@ import org.apache.cassandra.locator.EndpointsForRange; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.locator.Replica; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.service.reads.ReadCallback; import org.apache.cassandra.service.reads.repair.ReadRepairEvent.ReadRepairEventType; @@ -187,7 +187,7 @@ private void onRepairEvent(PartitionRepairEvent e) Assert.assertNotNull(e.toMap()); } - protected void sendRR(MessageOut message, InetAddressAndPort endpoint) + protected void sendRR(Message message, InetAddressAndPort endpoint) { } } diff --git a/test/unit/org/apache/cassandra/service/reads/repair/ReadRepairTest.java b/test/unit/org/apache/cassandra/service/reads/repair/ReadRepairTest.java index c3f05c0777bd..232644db58db 100644 --- a/test/unit/org/apache/cassandra/service/reads/repair/ReadRepairTest.java +++ b/test/unit/org/apache/cassandra/service/reads/repair/ReadRepairTest.java @@ -51,7 +51,7 @@ import org.apache.cassandra.locator.Replica; import org.apache.cassandra.locator.ReplicaLayout; import org.apache.cassandra.locator.ReplicaUtils; -import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.Message; import org.apache.cassandra.schema.KeyspaceMetadata; import org.apache.cassandra.schema.KeyspaceParams; import org.apache.cassandra.schema.MigrationManager; @@ -82,7 +82,7 @@ public InstrumentedReadRepairHandler(Map repairs, int maxBloc Map 
mutationsSent = new HashMap<>(); - protected void sendRR(MessageOut message, InetAddressAndPort endpoint) + protected void sendRR(Message message, InetAddressAndPort endpoint) { mutationsSent.put(endpoint, message.payload); } diff --git a/test/unit/org/apache/cassandra/streaming/StreamSessionTest.java b/test/unit/org/apache/cassandra/streaming/StreamSessionTest.java deleted file mode 100644 index 7ea09ea5b908..000000000000 --- a/test/unit/org/apache/cassandra/streaming/StreamSessionTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.cassandra.streaming; - -import java.net.InetAddress; -import java.net.UnknownHostException; -import java.util.Collections; -import java.util.UUID; - -import org.junit.BeforeClass; -import org.junit.Test; - -import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.MessagingServiceTest; - -import static org.junit.Assert.assertEquals; - -public class StreamSessionTest -{ - @BeforeClass - public static void beforeClass() throws UnknownHostException - { - DatabaseDescriptor.daemonInitialization(); - DatabaseDescriptor.setBackPressureStrategy(new MessagingServiceTest.MockBackPressureStrategy(Collections.emptyMap())); - DatabaseDescriptor.setBroadcastAddress(InetAddress.getByName("127.0.0.3")); - } - - @Test - public void testStreamSessionUsesCorrectRemoteIp_Succeeds() throws UnknownHostException - { - InetAddressAndPort localAddr = InetAddressAndPort.getByName("127.0.0.1:7000"); - InetAddressAndPort preferredAddr = InetAddressAndPort.getByName("127.0.0.2:7000"); - StreamSession streamSession = new StreamSession(StreamOperation.BOOTSTRAP, localAddr, - new DefaultConnectionFactory(), 0, UUID.randomUUID(), PreviewKind.ALL, - inetAddressAndPort -> preferredAddr); - - assertEquals(preferredAddr, streamSession.getMessageSender().getConnectionId().connectionAddress()); - } - - @Test - public void testStreamSessionUsesCorrectRemoteIpNullMapper_Succeeds() throws UnknownHostException - { - InetAddressAndPort localAddr = InetAddressAndPort.getByName("127.0.0.1:7000"); - - StreamSession streamSession = new StreamSession(StreamOperation.BOOTSTRAP, localAddr, - new DefaultConnectionFactory(), 0, UUID.randomUUID(), PreviewKind.ALL, (peer) -> null); - - assertEquals(localAddr, streamSession.getMessageSender().getConnectionId().connectionAddress()); - } -} diff --git a/test/unit/org/apache/cassandra/streaming/StreamTransferTaskTest.java 
b/test/unit/org/apache/cassandra/streaming/StreamTransferTaskTest.java index 8ebe333622b8..ae4e766ff764 100644 --- a/test/unit/org/apache/cassandra/streaming/StreamTransferTaskTest.java +++ b/test/unit/org/apache/cassandra/streaming/StreamTransferTaskTest.java @@ -76,14 +76,14 @@ public void tearDown() public void testScheduleTimeout() throws Exception { InetAddressAndPort peer = FBUtilities.getBroadcastAddressAndPort(); - StreamSession session = new StreamSession(StreamOperation.BOOTSTRAP, peer, (connectionId, protocolVersion) -> new EmbeddedChannel(), 0, UUID.randomUUID(), PreviewKind.ALL); + StreamSession session = new StreamSession(StreamOperation.BOOTSTRAP, peer, (template, messagingVersion) -> new EmbeddedChannel(), 0, UUID.randomUUID(), PreviewKind.ALL); ColumnFamilyStore cfs = Keyspace.open(KEYSPACE1).getColumnFamilyStore(CF_STANDARD); // create two sstables for (int i = 0; i < 2; i++) { SchemaLoader.insertData(KEYSPACE1, CF_STANDARD, i, 1); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // create streaming task that streams those two sstables @@ -132,7 +132,7 @@ public void testFailSessionDuringTransferShouldNotReleaseReferences() throws Exc for (int i = 0; i < 2; i++) { SchemaLoader.insertData(KEYSPACE1, CF_STANDARD, i, 1); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); } // create streaming task that streams those two sstables diff --git a/test/unit/org/apache/cassandra/streaming/StreamingTransferTest.java b/test/unit/org/apache/cassandra/streaming/StreamingTransferTest.java index 909e221ae28b..9746fd08449c 100644 --- a/test/unit/org/apache/cassandra/streaming/StreamingTransferTest.java +++ b/test/unit/org/apache/cassandra/streaming/StreamingTransferTest.java @@ -174,7 +174,7 @@ private List createAndTransfer(ColumnFamilyStore cfs, Mutator mutator, b long timestamp = 1234; for (int i = 1; i <= 3; i++) mutator.mutate("key" + i, "col" + i, timestamp); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); 
Util.compactAll(cfs, Integer.MAX_VALUE).get(); assertEquals(1, cfs.getLiveSSTables().size()); @@ -362,7 +362,7 @@ public void testTransferRangeTombstones() throws Exception .build() .apply(); - cfs.forceBlockingFlush(); + cfs.forceBlockingFlushToSSTable(); SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); cfs.clearUnsafe(); diff --git a/test/unit/org/apache/cassandra/streaming/async/NettyStreamingMessageSenderTest.java b/test/unit/org/apache/cassandra/streaming/async/NettyStreamingMessageSenderTest.java index 52f097af6211..957869b38269 100644 --- a/test/unit/org/apache/cassandra/streaming/async/NettyStreamingMessageSenderTest.java +++ b/test/unit/org/apache/cassandra/streaming/async/NettyStreamingMessageSenderTest.java @@ -32,10 +32,10 @@ import org.junit.Test; import io.netty.channel.ChannelPromise; -import io.netty.channel.embedded.EmbeddedChannel; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.TestScheduledFuture; +import org.apache.cassandra.net.TestChannel; +import org.apache.cassandra.net.TestScheduledFuture; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.streaming.StreamOperation; import org.apache.cassandra.streaming.StreamResultFuture; @@ -46,7 +46,7 @@ public class NettyStreamingMessageSenderTest { private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 0); - private EmbeddedChannel channel; + private TestChannel channel; private StreamSession session; private NettyStreamingMessageSender sender; private NettyStreamingMessageSender.FileStreamTask fileStreamTask; @@ -60,10 +60,10 @@ public static void before() @Before public void setUp() { - channel = new EmbeddedChannel(); + channel = new TestChannel(Integer.MAX_VALUE); channel.attr(NettyStreamingMessageSender.TRANSFERRING_FILE_ATTR).set(Boolean.FALSE); UUID pendingRepair = 
UUID.randomUUID(); - session = new StreamSession(StreamOperation.BOOTSTRAP, REMOTE_ADDR, (connectionId, protocolVersion) -> null, 0, pendingRepair, PreviewKind.ALL); + session = new StreamSession(StreamOperation.BOOTSTRAP, REMOTE_ADDR, (template, messagingVersion) -> null, 0, pendingRepair, PreviewKind.ALL); StreamResultFuture future = StreamResultFuture.initReceivingSide(0, UUID.randomUUID(), StreamOperation.REPAIR, REMOTE_ADDR, channel, pendingRepair, session.getPreviewKind()); session.init(future); sender = session.getMessageSender(); diff --git a/test/unit/org/apache/cassandra/streaming/async/StreamCompressionSerializerTest.java b/test/unit/org/apache/cassandra/streaming/async/StreamCompressionSerializerTest.java index e274f27e95aa..a88092ed40e1 100644 --- a/test/unit/org/apache/cassandra/streaming/async/StreamCompressionSerializerTest.java +++ b/test/unit/org/apache/cassandra/streaming/async/StreamCompressionSerializerTest.java @@ -31,17 +31,18 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; import net.jpountz.lz4.LZ4Compressor; import net.jpountz.lz4.LZ4Factory; import net.jpountz.lz4.LZ4FastDecompressor; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.FileUtils; -import org.apache.cassandra.streaming.messages.StreamMessage; +import org.apache.cassandra.net.MessagingService; public class StreamCompressionSerializerTest { - private static final int VERSION = StreamMessage.CURRENT_VERSION; + private static final int VERSION = MessagingService.current_version; private static final Random random = new Random(2347623847623L); private final ByteBufAllocator allocator = PooledByteBufAllocator.DEFAULT; @@ -50,7 +51,7 @@ public class StreamCompressionSerializerTest private final LZ4FastDecompressor decompressor = LZ4Factory.fastestInstance().fastDecompressor(); 
private ByteBuffer input; - private ByteBuf compressed; + private ByteBuffer compressed; private ByteBuf output; @BeforeClass @@ -64,8 +65,8 @@ public void tearDown() { if (input != null) FileUtils.clean(input); - if (compressed != null && compressed.refCnt() > 0) - compressed.release(compressed.refCnt()); + if (compressed != null) + FileUtils.clean(compressed); if (output != null && output.refCnt() > 0) output.release(output.refCnt()); } @@ -74,10 +75,9 @@ public void tearDown() public void roundTrip_HappyPath_NotReadabaleByteBuffer() throws IOException { populateInput(); - compressed = serializer.serialize(compressor, input, VERSION); + StreamCompressionSerializer.serialize(compressor, input, VERSION).write(size -> compressed = ByteBuffer.allocateDirect(size)); input.flip(); - ByteBuffer compressedNioBuffer = compressed.nioBuffer(0, compressed.writerIndex()); - output = serializer.deserialize(decompressor, new DataInputBuffer(compressedNioBuffer, false), VERSION); + output = serializer.deserialize(decompressor, new DataInputBuffer(compressed, false), VERSION); validateResults(); } @@ -101,9 +101,14 @@ private void validateResults() public void roundTrip_HappyPath_ReadabaleByteBuffer() throws IOException { populateInput(); - compressed = serializer.serialize(compressor, input, VERSION); + StreamCompressionSerializer.serialize(compressor, input, VERSION) + .write(size -> { + if (compressed != null) + FileUtils.clean(compressed); + return compressed = ByteBuffer.allocateDirect(size); + }); input.flip(); - output = serializer.deserialize(decompressor, new ByteBufRCH(compressed), VERSION); + output = serializer.deserialize(decompressor, new ByteBufRCH(Unpooled.wrappedBuffer(compressed)), VERSION); validateResults(); } diff --git a/test/unit/org/apache/cassandra/streaming/async/StreamingInboundHandlerTest.java b/test/unit/org/apache/cassandra/streaming/async/StreamingInboundHandlerTest.java index 0a135960f970..6a2afe8e4024 100644 --- 
a/test/unit/org/apache/cassandra/streaming/async/StreamingInboundHandlerTest.java +++ b/test/unit/org/apache/cassandra/streaming/async/StreamingInboundHandlerTest.java @@ -18,9 +18,6 @@ package org.apache.cassandra.streaming.async; -import java.io.EOFException; -import java.io.IOException; -import java.util.ArrayList; import java.util.UUID; import com.google.common.net.InetAddresses; @@ -33,10 +30,9 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.embedded.EmbeddedChannel; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.io.sstable.format.SSTableFormat; -import org.apache.cassandra.io.sstable.format.big.BigFormat; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.async.RebufferingByteBufDataInputPlus; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.AsyncStreamingInputPlus; import org.apache.cassandra.schema.TableId; import org.apache.cassandra.streaming.PreviewKind; import org.apache.cassandra.streaming.StreamManager; @@ -45,19 +41,18 @@ import org.apache.cassandra.streaming.StreamSession; import org.apache.cassandra.streaming.async.StreamingInboundHandler.SessionIdentifier; import org.apache.cassandra.streaming.messages.CompleteMessage; -import org.apache.cassandra.streaming.messages.StreamMessageHeader; import org.apache.cassandra.streaming.messages.IncomingStreamMessage; import org.apache.cassandra.streaming.messages.StreamInitMessage; -import org.apache.cassandra.streaming.messages.StreamMessage; +import org.apache.cassandra.streaming.messages.StreamMessageHeader; public class StreamingInboundHandlerTest { - private static final int VERSION = StreamMessage.CURRENT_VERSION; + private static final int VERSION = MessagingService.current_version; private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 0); private StreamingInboundHandler handler; private 
EmbeddedChannel channel; - private RebufferingByteBufDataInputPlus buffers; + private AsyncStreamingInputPlus buffers; private ByteBuf buf; @BeforeClass @@ -71,7 +66,7 @@ public void setup() { handler = new StreamingInboundHandler(REMOTE_ADDR, VERSION, null); channel = new EmbeddedChannel(handler); - buffers = new RebufferingByteBufDataInputPlus(1 << 9, 1 << 10, channel.config()); + buffers = new AsyncStreamingInputPlus(channel); handler.setPendingBuffers(buffers); } @@ -88,19 +83,19 @@ public void tearDown() } @Test - public void channelRead_Normal() throws EOFException + public void channelRead_Normal() { - Assert.assertEquals(0, buffers.available()); + Assert.assertEquals(0, buffers.unsafeAvailable()); int size = 8; buf = channel.alloc().buffer(size); buf.writerIndex(size); channel.writeInbound(buf); - Assert.assertEquals(size, buffers.available()); + Assert.assertEquals(size, buffers.unsafeAvailable()); Assert.assertFalse(channel.releaseInbound()); } - @Test (expected = EOFException.class) - public void channelRead_Closed() throws EOFException + @Test + public void channelRead_Closed() { int size = 8; buf = channel.alloc().buffer(size); @@ -108,21 +103,21 @@ public void channelRead_Closed() throws EOFException buf.writerIndex(size); handler.close(); channel.writeInbound(buf); - Assert.assertEquals(0, buffers.available()); + Assert.assertEquals(0, buffers.unsafeAvailable()); Assert.assertEquals(0, buf.refCnt()); Assert.assertFalse(channel.releaseInbound()); } @Test - public void channelRead_WrongObject() throws EOFException + public void channelRead_WrongObject() { channel.writeInbound("homer"); - Assert.assertEquals(0, buffers.available()); + Assert.assertEquals(0, buffers.unsafeAvailable()); Assert.assertFalse(channel.releaseInbound()); } @Test - public void StreamDeserializingTask_deriveSession_StreamInitMessage() throws InterruptedException, IOException + public void StreamDeserializingTask_deriveSession_StreamInitMessage() { StreamInitMessage msg = new 
StreamInitMessage(REMOTE_ADDR, 0, UUID.randomUUID(), StreamOperation.REPAIR, UUID.randomUUID(), PreviewKind.ALL); StreamingInboundHandler.StreamDeserializingTask task = handler.new StreamDeserializingTask(sid -> createSession(sid), null, channel); @@ -132,11 +127,11 @@ public void StreamDeserializingTask_deriveSession_StreamInitMessage() throws Int private StreamSession createSession(SessionIdentifier sid) { - return new StreamSession(StreamOperation.BOOTSTRAP, sid.from, (connectionId, protocolVersion) -> null, sid.sessionIndex, UUID.randomUUID(), PreviewKind.ALL); + return new StreamSession(StreamOperation.BOOTSTRAP, sid.from, (template, messagingVersion) -> null, sid.sessionIndex, UUID.randomUUID(), PreviewKind.ALL); } @Test (expected = IllegalStateException.class) - public void StreamDeserializingTask_deriveSession_NoSession() throws InterruptedException, IOException + public void StreamDeserializingTask_deriveSession_NoSession() { CompleteMessage msg = new CompleteMessage(); StreamingInboundHandler.StreamDeserializingTask task = handler.new StreamDeserializingTask(sid -> createSession(sid), null, channel); @@ -144,7 +139,7 @@ public void StreamDeserializingTask_deriveSession_NoSession() throws Interrupted } @Test (expected = IllegalStateException.class) - public void StreamDeserializingTask_deriveSession_IFM_NoSession() throws InterruptedException, IOException + public void StreamDeserializingTask_deriveSession_IFM_NoSession() { StreamMessageHeader header = new StreamMessageHeader(TableId.generate(), REMOTE_ADDR, UUID.randomUUID(), 0, 0, 0, UUID.randomUUID()); @@ -154,7 +149,7 @@ public void StreamDeserializingTask_deriveSession_IFM_NoSession() throws Interru } @Test - public void StreamDeserializingTask_deriveSession_IFM_HasSession() throws InterruptedException, IOException + public void StreamDeserializingTask_deriveSession_IFM_HasSession() { UUID planId = UUID.randomUUID(); StreamResultFuture future = StreamResultFuture.initReceivingSide(0, planId, 
StreamOperation.REPAIR, REMOTE_ADDR, channel, UUID.randomUUID(), PreviewKind.ALL); diff --git a/test/unit/org/apache/cassandra/transport/IdleDisconnectTest.java b/test/unit/org/apache/cassandra/transport/IdleDisconnectTest.java new file mode 100644 index 000000000000..2c8adeabae0f --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/IdleDisconnectTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.transport; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.ConsistencyLevel; + +public class IdleDisconnectTest extends CQLTester +{ + private static final long TIMEOUT = 2000L; + + @BeforeClass + public static void setUp() + { + requireNetwork(); + DatabaseDescriptor.setNativeTransportIdleTimeout(TIMEOUT); + } + + @Test + public void testIdleDisconnect() throws Throwable + { + DatabaseDescriptor.setNativeTransportIdleTimeout(TIMEOUT); + try (SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), nativePort)) + { + client.connect(false, false); + Assert.assertTrue(client.channel.isOpen()); + long start = System.currentTimeMillis(); + CompletableFuture.runAsync(() -> { + while (!Thread.currentThread().isInterrupted() && client.channel.isOpen()); + }).get(30, TimeUnit.SECONDS); + Assert.assertFalse(client.channel.isOpen()); + Assert.assertTrue(System.currentTimeMillis() - start >= TIMEOUT); + } + } + + @Test + public void testIdleDisconnectProlonged() throws Throwable + { + long sleepTime = 1000; + try (SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), nativePort)) + { + client.connect(false, false); + Assert.assertTrue(client.channel.isOpen()); + long start = System.currentTimeMillis(); + Thread.sleep(sleepTime); + client.execute("SELECT * FROM system.peers", ConsistencyLevel.ONE); + CompletableFuture.runAsync(() -> { + while (!Thread.currentThread().isInterrupted() && client.channel.isOpen()); + }).get(30, TimeUnit.SECONDS); + Assert.assertFalse(client.channel.isOpen()); + Assert.assertTrue(System.currentTimeMillis() - start >= TIMEOUT + sleepTime); + } + } +} diff --git 
a/test/unit/org/apache/cassandra/transport/InflightRequestPayloadTrackerTest.java b/test/unit/org/apache/cassandra/transport/InflightRequestPayloadTrackerTest.java new file mode 100644 index 000000000000..c9a9a02e8181 --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/InflightRequestPayloadTrackerTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.transport; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.apache.cassandra.OrderedJUnit4ClassRunner; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.QueryProcessor; +import org.apache.cassandra.exceptions.OverloadedException; +import org.apache.cassandra.transport.messages.QueryMessage; + +@RunWith(OrderedJUnit4ClassRunner.class) +public class InflightRequestPayloadTrackerTest extends CQLTester +{ + @BeforeClass + public static void setUp() + { + DatabaseDescriptor.setNativeTransportMaxConcurrentRequestsInBytesPerIp(600); + DatabaseDescriptor.setNativeTransportMaxConcurrentRequestsInBytes(600); + requireNetwork(); + } + + @AfterClass + public static void tearDown() + { + DatabaseDescriptor.setNativeTransportMaxConcurrentRequestsInBytesPerIp(3000000000L); + DatabaseDescriptor.setNativeTransportMaxConcurrentRequestsInBytes(5000000000L); + } + + @After + public void dropCreatedTable() + { + try + { + QueryProcessor.executeOnceInternal("DROP TABLE " + KEYSPACE + ".atable"); + } + catch (Throwable t) + { + // ignore + } + } + + @Test + public void testQueryExecutionWithThrowOnOverload() throws Throwable + { + SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), + nativePort, + ProtocolVersion.V5, + true, + new EncryptionOptions()); + + try + { + client.connect(false, false, true); + QueryOptions queryOptions = QueryOptions.create( + QueryOptions.DEFAULT.getConsistency(), + QueryOptions.DEFAULT.getValues(), + QueryOptions.DEFAULT.skipMetadata(), + QueryOptions.DEFAULT.getPageSize(), + QueryOptions.DEFAULT.getPagingState(), + QueryOptions.DEFAULT.getSerialConsistency(), + ProtocolVersion.V5, + 
KEYSPACE); + + QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable (pk1 int PRIMARY KEY, v text)", + queryOptions); + client.execute(queryMessage); + } + finally + { + client.close(); + } + } + + @Test + public void testQueryExecutionWithoutThrowOnOverload() throws Throwable + { + SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), + nativePort, + ProtocolVersion.V5, + true, + new EncryptionOptions()); + + try + { + client.connect(false, false, false); + QueryOptions queryOptions = QueryOptions.create( + QueryOptions.DEFAULT.getConsistency(), + QueryOptions.DEFAULT.getValues(), + QueryOptions.DEFAULT.skipMetadata(), + QueryOptions.DEFAULT.getPageSize(), + QueryOptions.DEFAULT.getPagingState(), + QueryOptions.DEFAULT.getSerialConsistency(), + ProtocolVersion.V5, + KEYSPACE); + + QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, v text)", + queryOptions); + client.execute(queryMessage); + queryMessage = new QueryMessage("SELECT * FROM atable", + queryOptions); + client.execute(queryMessage); + } + finally + { + client.close(); + } + } + + @Test + public void testQueryExecutionWithoutThrowOnOverloadAndInflightLimitedExceeded() throws Throwable + { + SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), + nativePort, + ProtocolVersion.V5, + true, + new EncryptionOptions()); + + try + { + client.connect(false, false, false); + QueryOptions queryOptions = QueryOptions.create( + QueryOptions.DEFAULT.getConsistency(), + QueryOptions.DEFAULT.getValues(), + QueryOptions.DEFAULT.skipMetadata(), + QueryOptions.DEFAULT.getPageSize(), + QueryOptions.DEFAULT.getPagingState(), + QueryOptions.DEFAULT.getSerialConsistency(), + ProtocolVersion.V5, + KEYSPACE); + + QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, v text)", + queryOptions); + client.execute(queryMessage); + + queryMessage = new QueryMessage("INSERT INTO atable (pk, v) VALUES (1, 
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')", + queryOptions); + client.execute(queryMessage); + } + finally + { + client.close(); + } + } + + @Test + public void testOverloadedExceptionForEndpointInflightLimit() throws Throwable + { + SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), + nativePort, + ProtocolVersion.V5, + true, + new EncryptionOptions()); + + try + { + client.connect(false, false, true); + QueryOptions queryOptions = QueryOptions.create( + QueryOptions.DEFAULT.getConsistency(), + QueryOptions.DEFAULT.getValues(), + QueryOptions.DEFAULT.skipMetadata(), + QueryOptions.DEFAULT.getPageSize(), + QueryOptions.DEFAULT.getPagingState(), + QueryOptions.DEFAULT.getSerialConsistency(), + ProtocolVersion.V5, + KEYSPACE); + + QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, v text)", + queryOptions); + client.execute(queryMessage); + + queryMessage = new QueryMessage("INSERT INTO atable (pk, v) VALUES (1, 
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')", + queryOptions); + try + { + client.execute(queryMessage); + Assert.fail(); + } + catch (RuntimeException e) + { + Assert.assertTrue(e.getCause() instanceof OverloadedException); + } + } + finally + { + client.close(); + } + } + + @Test + public void testOverloadedExceptionForOverallInflightLimit() throws Throwable + { + SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), + nativePort, + ProtocolVersion.V5, + true, + new EncryptionOptions()); + + try + { + client.connect(false, false, true); + QueryOptions queryOptions = QueryOptions.create( + QueryOptions.DEFAULT.getConsistency(), + QueryOptions.DEFAULT.getValues(), + QueryOptions.DEFAULT.skipMetadata(), + QueryOptions.DEFAULT.getPageSize(), + QueryOptions.DEFAULT.getPagingState(), + QueryOptions.DEFAULT.getSerialConsistency(), + ProtocolVersion.V5, + KEYSPACE); + + QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, v text)", + queryOptions); + client.execute(queryMessage); + + queryMessage = new QueryMessage("INSERT INTO atable (pk, v) VALUES (1, 
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')", + queryOptions); + try + { + client.execute(queryMessage); + Assert.fail(); + } + catch (RuntimeException e) + { + Assert.assertTrue(e.getCause() instanceof OverloadedException); + } + } + finally + { + client.close(); + } + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/utils/ByteBufferUtilTest.java b/test/unit/org/apache/cassandra/utils/ByteBufferUtilTest.java index c5de60b64c49..4ae8626caeb2 100644 --- a/test/unit/org/apache/cassandra/utils/ByteBufferUtilTest.java +++ b/test/unit/org/apache/cassandra/utils/ByteBufferUtilTest.java @@ -153,11 +153,11 @@ private void checkArrayCopy(ByteBuffer bb) { byte[] bytes = new byte[s.length()]; - ByteBufferUtil.arrayCopy(bb, bb.position(), bytes, 0, s.length()); + ByteBufferUtil.copyBytes(bb, bb.position(), bytes, 0, s.length()); assertArrayEquals(s.getBytes(), bytes); bytes = new byte[5]; - ByteBufferUtil.arrayCopy(bb, bb.position() + 3, bytes, 1, 4); + ByteBufferUtil.copyBytes(bb, bb.position() + 3, bytes, 1, 4); assertArrayEquals(Arrays.copyOfRange(s.getBytes(), 3, 7), Arrays.copyOfRange(bytes, 1, 5)); } 
diff --git a/test/unit/org/apache/cassandra/utils/CoalescingStrategiesTest.java b/test/unit/org/apache/cassandra/utils/CoalescingStrategiesTest.java deleted file mode 100644 index 8877fe9c1ace..000000000000 --- a/test/unit/org/apache/cassandra/utils/CoalescingStrategiesTest.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.cassandra.utils; - -import java.util.concurrent.TimeUnit; - -import org.junit.Assert; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.utils.CoalescingStrategies.Coalescable; -import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; -import org.apache.cassandra.utils.CoalescingStrategies.FixedCoalescingStrategy; -import org.apache.cassandra.utils.CoalescingStrategies.MovingAverageCoalescingStrategy; -import org.apache.cassandra.utils.CoalescingStrategies.TimeHorizonMovingAverageCoalescingStrategy; - -public class CoalescingStrategiesTest -{ - private static final Logger logger = LoggerFactory.getLogger(CoalescingStrategiesTest.class); - private static final int WINDOW_IN_MICROS = 200; - private static final long WINDOW_IN_NANOS = TimeUnit.MICROSECONDS.toNanos(WINDOW_IN_MICROS); - private static final String DISPLAY_NAME = "Stupendopotamus"; - - static class SimpleCoalescable implements Coalescable - { - final long timestampNanos; - - SimpleCoalescable(long timestampNanos) - { - this.timestampNanos = timestampNanos; - } - - public long timestampNanos() - { - return timestampNanos; - } - } - - static long toNanos(long micros) - { - return TimeUnit.MICROSECONDS.toNanos(micros); - } - - @Test - public void testFixedCoalescingStrategy() - { - CoalescingStrategy cs = new FixedCoalescingStrategy(WINDOW_IN_MICROS, logger, DISPLAY_NAME); - Assert.assertEquals(WINDOW_IN_NANOS, cs.currentCoalescingTimeNanos()); - } - - @Test - public void testMovingAverageCoalescingStrategy_DoCoalesce() - { - CoalescingStrategy cs = new MovingAverageCoalescingStrategy(WINDOW_IN_MICROS, logger, DISPLAY_NAME); - - for (int i = 0; i < MovingAverageCoalescingStrategy.SAMPLE_SIZE; i++) - cs.newArrival(new SimpleCoalescable(toNanos(i))); - Assert.assertTrue(0 < cs.currentCoalescingTimeNanos()); - } - - @Test - public void testMovingAverageCoalescingStrategy_DoNotCoalesce() - { - 
CoalescingStrategy cs = new MovingAverageCoalescingStrategy(WINDOW_IN_MICROS, logger, DISPLAY_NAME); - - for (int i = 0; i < MovingAverageCoalescingStrategy.SAMPLE_SIZE; i++) - cs.newArrival(new SimpleCoalescable(toNanos(WINDOW_IN_MICROS + i) * i)); - Assert.assertTrue(0 >= cs.currentCoalescingTimeNanos()); - } - - @Test - public void testTimeHorizonStrategy_DoCoalesce() - { - long initialEpoch = 0; - CoalescingStrategy cs = new TimeHorizonMovingAverageCoalescingStrategy(WINDOW_IN_MICROS, logger, DISPLAY_NAME, initialEpoch); - - for (int i = 0; i < 10_000; i++) - cs.newArrival(new SimpleCoalescable(toNanos(i))); - Assert.assertTrue(0 < cs.currentCoalescingTimeNanos()); - } - - @Test - public void testTimeHorizonStrategy_DoNotCoalesce() - { - long initialEpoch = 0; - CoalescingStrategy cs = new TimeHorizonMovingAverageCoalescingStrategy(WINDOW_IN_MICROS, logger, DISPLAY_NAME, initialEpoch); - - for (int i = 0; i < 1_000_000; i++) - cs.newArrival(new SimpleCoalescable(toNanos(WINDOW_IN_MICROS + i) * i)); - Assert.assertTrue(0 >= cs.currentCoalescingTimeNanos()); - } - - @Test - public void determineCoalescingTime_LargeAverageGap() - { - Assert.assertTrue(0 >= CoalescingStrategies.determineCoalescingTime(WINDOW_IN_NANOS * 2, WINDOW_IN_NANOS)); - Assert.assertTrue(0 >= CoalescingStrategies.determineCoalescingTime(Integer.MAX_VALUE, WINDOW_IN_NANOS)); - } - - @Test - public void determineCoalescingTime_SmallAvgGap() - { - Assert.assertTrue(WINDOW_IN_NANOS >= CoalescingStrategies.determineCoalescingTime(WINDOW_IN_NANOS / 2, WINDOW_IN_NANOS)); - Assert.assertTrue(WINDOW_IN_NANOS >= CoalescingStrategies.determineCoalescingTime(WINDOW_IN_NANOS - 1, WINDOW_IN_NANOS)); - Assert.assertTrue(WINDOW_IN_NANOS >= CoalescingStrategies.determineCoalescingTime(1, WINDOW_IN_NANOS)); - Assert.assertEquals(WINDOW_IN_NANOS, CoalescingStrategies.determineCoalescingTime(0, WINDOW_IN_NANOS)); - } -} diff --git a/test/unit/org/apache/cassandra/utils/FreeRunningClock.java 
b/test/unit/org/apache/cassandra/utils/FreeRunningClock.java index 83c8db703aaa..d85383389dbe 100644 --- a/test/unit/org/apache/cassandra/utils/FreeRunningClock.java +++ b/test/unit/org/apache/cassandra/utils/FreeRunningClock.java @@ -20,23 +20,41 @@ import java.util.concurrent.TimeUnit; /** - * A freely adjustable clock that can be used for unit testing. See {@link Clock#instance} how to + * A freely adjustable clock that can be used for unit testing. See {@link MonotonicClock#instance} how to * enable this class. */ -public class FreeRunningClock extends Clock +public class FreeRunningClock implements MonotonicClock { private long nanoTime = 0; @Override - public long nanoTime() + public long now() { return nanoTime; } @Override - public long currentTimeMillis() + public long error() { - return TimeUnit.NANOSECONDS.toMillis(nanoTime()); + return 0; + } + + @Override + public MonotonicClockTranslation translate() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isAfter(long instant) + { + return instant > nanoTime; + } + + @Override + public boolean isAfter(long now, long instant) + { + return now > instant; } public void advance(long time, TimeUnit unit) diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java index c213271c0f88..36ae4a0479fb 100644 --- a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java +++ b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java @@ -1,29 +1,27 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyten ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.cassandra.utils; +import java.io.IOException; import java.math.BigInteger; import java.nio.ByteBuffer; import java.util.*; -import com.google.common.collect.Lists; - import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -36,20 +34,24 @@ import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.MerkleTree.Hashable; import org.apache.cassandra.utils.MerkleTree.RowHash; import org.apache.cassandra.utils.MerkleTree.TreeRange; import org.apache.cassandra.utils.MerkleTree.TreeRangeIterator; +import static com.google.common.collect.Lists.newArrayList; import static org.apache.cassandra.utils.MerkleTree.RECOMMENDED_DEPTH; import static org.junit.Assert.*; public class MerkleTreeTest { - public static byte[] DUMMY = "blah".getBytes(); + private static final byte[] DUMMY = digest("dummy"); + + private static byte[] digest(String string) + { + return HashingUtils.newMessageDigest("SHA-256").digest(string.getBytes()); + } /** * If a test assumes that the tree is 8 units wide, then it should set this value @@ -68,6 +70,9 @@ private Range fullRange() @Before public void setup() { + DatabaseDescriptor.clientInitialization(); + DatabaseDescriptor.useOffheapMerkleTrees(false); + TOKEN_SCALE = new BigInteger("8"); partitioner = RandomPartitioner.instance; // TODO need to trickle TokenSerializer @@ -171,7 +176,7 @@ public void testInvalids() Iterator ranges; // (zero, zero] - ranges = mt.invalids(); + ranges = mt.rangeIterator(); assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next()); assertFalse(ranges.hasNext()); @@ -181,7 +186,7 @@ public void testInvalids() mt.split(tok(6)); mt.split(tok(3)); mt.split(tok(5)); - ranges = mt.invalids(); + ranges = mt.rangeIterator(); 
assertEquals(new Range<>(tok(6), tok(-1)), ranges.next()); assertEquals(new Range<>(tok(-1), tok(2)), ranges.next()); assertEquals(new Range<>(tok(2), tok(3)), ranges.next()); @@ -200,7 +205,7 @@ public void testHashFull() Range range = new Range<>(tok(-1), tok(-1)); // (zero, zero] - assertNull(mt.hash(range)); + assertFalse(mt.hashesRange(range)); // validate the range mt.get(tok(-1)).hash(val); @@ -223,11 +228,12 @@ public void testHashPartial() // (zero,two] (two,four] (four, zero] mt.split(tok(4)); mt.split(tok(2)); - assertNull(mt.hash(left)); - assertNull(mt.hash(partial)); - assertNull(mt.hash(right)); - assertNull(mt.hash(linvalid)); - assertNull(mt.hash(rinvalid)); + + assertFalse(mt.hashesRange(left)); + assertFalse(mt.hashesRange(partial)); + assertFalse(mt.hashesRange(right)); + assertFalse(mt.hashesRange(linvalid)); + assertFalse(mt.hashesRange(rinvalid)); // validate the range mt.get(tok(2)).hash(val); @@ -237,8 +243,8 @@ public void testHashPartial() assertHashEquals(leftval, mt.hash(left)); assertHashEquals(partialval, mt.hash(partial)); assertHashEquals(val, mt.hash(right)); - assertNull(mt.hash(linvalid)); - assertNull(mt.hash(rinvalid)); + assertFalse(mt.hashesRange(linvalid)); + assertFalse(mt.hashesRange(rinvalid)); } @Test @@ -258,10 +264,6 @@ public void testHashInner() mt.split(tok(2)); mt.split(tok(6)); mt.split(tok(1)); - assertNull(mt.hash(full)); - assertNull(mt.hash(lchild)); - assertNull(mt.hash(rchild)); - assertNull(mt.hash(invalid)); // validate the range mt.get(tok(1)).hash(val); @@ -270,10 +272,14 @@ public void testHashInner() mt.get(tok(6)).hash(val); mt.get(tok(-1)).hash(val); + assertTrue(mt.hashesRange(full)); + assertTrue(mt.hashesRange(lchild)); + assertTrue(mt.hashesRange(rchild)); + assertFalse(mt.hashesRange(invalid)); + assertHashEquals(fullval, mt.hash(full)); assertHashEquals(lchildval, mt.hash(lchild)); assertHashEquals(rchildval, mt.hash(rchild)); - assertNull(mt.hash(invalid)); } @Test @@ -294,9 +300,6 @@ public 
void testHashDegenerate() mt.split(tok(4)); mt.split(tok(2)); mt.split(tok(1)); - assertNull(mt.hash(full)); - assertNull(mt.hash(childfull)); - assertNull(mt.hash(invalid)); // validate the range mt.get(tok(1)).hash(val); @@ -306,9 +309,12 @@ public void testHashDegenerate() mt.get(tok(16)).hash(val); mt.get(tok(-1)).hash(val); + assertTrue(mt.hashesRange(full)); + assertTrue(mt.hashesRange(childfull)); + assertFalse(mt.hashesRange(invalid)); + assertHashEquals(fullval, mt.hash(full)); assertHashEquals(childfullval, mt.hash(childfull)); - assertNull(mt.hash(invalid)); } @Test @@ -326,7 +332,7 @@ public void testHashRandom() } // validate the tree - TreeRangeIterator ranges = mt.invalids(); + TreeRangeIterator ranges = mt.rangeIterator(); for (TreeRange range : ranges) range.addHash(new RowHash(range.right, new byte[0], 0)); @@ -355,7 +361,7 @@ public void testValidateTree() mt.split(tok(6)); mt.split(tok(10)); - ranges = mt.invalids(); + ranges = mt.rangeIterator(); ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2 ranges.next().addAll(new HIterator(6)); // (4,6] ranges.next().addAll(new HIterator(8)); // (6,8] @@ -372,7 +378,7 @@ public void testValidateTree() mt2.split(tok(9)); mt2.split(tok(11)); - ranges = mt2.invalids(); + ranges = mt2.rangeIterator(); ranges.next().addAll(new HIterator(2)); // (-1,2] ranges.next().addAll(new HIterator(4)); // (2,4] ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2 @@ -395,19 +401,33 @@ public void testSerialization() throws Exception // populate and validate the tree mt.maxsize(256); mt.init(); - for (TreeRange range : mt.invalids()) + for (TreeRange range : mt.rangeIterator()) range.addAll(new HIterator(range.right)); byte[] initialhash = mt.hash(full); DataOutputBuffer out = new DataOutputBuffer(); - MerkleTree.serializer.serialize(mt, out, MessagingService.current_version); + mt.serialize(out, MessagingService.current_version); byte[] serialized = out.toByteArray(); - DataInputPlus in = new 
DataInputBuffer(serialized); - MerkleTree restored = MerkleTree.serializer.deserialize(in, MessagingService.current_version); + MerkleTree restoredOnHeap = + MerkleTree.deserialize(new DataInputBuffer(serialized), false, MessagingService.current_version); + MerkleTree restoredOffHeap = + MerkleTree.deserialize(new DataInputBuffer(serialized), true, MessagingService.current_version); + MerkleTree movedOffHeap = mt.moveOffHeap(); + + assertHashEquals(initialhash, restoredOnHeap.hash(full)); + assertHashEquals(initialhash, restoredOffHeap.hash(full)); + assertHashEquals(initialhash, movedOffHeap.hash(full)); + + assertEquals(mt, restoredOnHeap); + assertEquals(mt, restoredOffHeap); + assertEquals(mt, movedOffHeap); - assertHashEquals(initialhash, restored.hash(full)); + assertEquals(restoredOnHeap, restoredOffHeap); + assertEquals(restoredOnHeap, movedOffHeap); + + assertEquals(restoredOffHeap, movedOffHeap); } @Test @@ -420,9 +440,9 @@ public void testDifference() mt2.init(); // add dummy hashes to both trees - for (TreeRange range : mt.invalids()) + for (TreeRange range : mt.rangeIterator()) range.addAll(new HIterator(range.right)); - for (TreeRange range : mt2.invalids()) + for (TreeRange range : mt2.rangeIterator()) range.addAll(new HIterator(range.right)); TreeRange leftmost = null; @@ -431,14 +451,14 @@ public void testDifference() mt.maxsize(maxsize + 2); // give some room for splitting // split the leftmost - Iterator ranges = mt.invalids(); + Iterator ranges = mt.rangeIterator(); leftmost = ranges.next(); mt.split(leftmost.right); // set the hashes for the leaf of the created split middle = mt.get(leftmost.right); - middle.hash("arbitrary!".getBytes()); - mt.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash("even more arbitrary!".getBytes()); + middle.hash(digest("arbitrary!")); + mt.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash(digest("even more arbitrary!")); // trees should disagree for (leftmost.left, middle.right] List diffs = 
MerkleTree.difference(mt, mt2); @@ -461,22 +481,23 @@ public void differenceSmallRange() MerkleTree rtree = new MerkleTree(partitioner, range, RECOMMENDED_DEPTH, 16); rtree.init(); - byte[] h1 = "asdf".getBytes(); - byte[] h2 = "hjkl".getBytes(); + byte[] h1 = digest("asdf"); + byte[] h2 = digest("hjkl"); // add dummy hashes to both trees - for (TreeRange tree : ltree.invalids()) + for (TreeRange tree : ltree.rangeIterator()) { tree.addHash(new RowHash(range.right, h1, h1.length)); } - for (TreeRange tree : rtree.invalids()) + for (TreeRange tree : rtree.rangeIterator()) { tree.addHash(new RowHash(range.right, h2, h2.length)); } List diffs = MerkleTree.difference(ltree, rtree); - assertEquals(Lists.newArrayList(range), diffs); - assertEquals(MerkleTree.FULLY_INCONSISTENT, MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte) 0))); + assertEquals(newArrayList(range), diffs); + assertEquals(MerkleTree.Difference.FULLY_INCONSISTENT, + MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeRange(ltree.fullRange.left, ltree.fullRange.right, (byte)0))); } /** @@ -494,22 +515,22 @@ public void matchingSmallRange() MerkleTree rtree = new MerkleTree(partitioner, range, RECOMMENDED_DEPTH, 16); rtree.init(); - byte[] h1 = "asdf".getBytes(); - byte[] h2 = "asdf".getBytes(); + byte[] h1 = digest("asdf"); + byte[] h2 = digest("asdf"); // add dummy hashes to both trees - for (TreeRange tree : ltree.invalids()) + for (TreeRange tree : ltree.rangeIterator()) { tree.addHash(new RowHash(range.right, h1, h1.length)); } - for (TreeRange tree : rtree.invalids()) + for (TreeRange tree : rtree.rangeIterator()) { tree.addHash(new RowHash(range.right, h2, h2.length)); } // top level difference() should show no differences - assertEquals(MerkleTree.difference(ltree, rtree), Lists.newArrayList()); + assertEquals(MerkleTree.difference(ltree, rtree), newArrayList()); } /** @@ -533,7 
+554,7 @@ byte[] hashed(byte[] val, Integer... depths) while (depth.equals(dstack.peek())) { // consume the stack - hash = Hashable.binaryHash(hstack.pop(), hash); + hash = MerkleTree.xor(hstack.pop(), hash); depth = dstack.pop() - 1; } dstack.push(depth); @@ -643,4 +664,251 @@ private long measureTree(MerkleTree tree, Range fullRange, int depth, Ran tree.hash(fullRange); return ObjectSizes.measureDeep(tree); } + + @Test + public void testEqualTreesSameDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 3, 3); + testDifferences(trees, Collections.emptyList()); + } + + @Test + public void testEqualTreesDifferentDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 2, 3); + testDifferences(trees, Collections.emptyList()); + } + + @Test + public void testEntirelyDifferentTrees() throws IOException + { + int seed1 = makeSeed(); + int seed2 = seed1 * 32; + Trees trees = Trees.make(seed1, seed2, 3, 3); + testDifferences(trees, newArrayList(makeTreeRange(0, 16, 0))); + } + + @Test + public void testDifferentTrees1SameDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 3, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 2, 3))); + } + + @Test + public void testDifferentTrees1DifferentDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 2, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 4, 2))); + } + + @Test + public void testDifferentTrees2SameDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 3, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + trees.tree2.get(longToken(16)).addHash(digest("diff_16"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 2, 3), + makeTreeRange(14, 16, 3))); + } + 
+ @Test + public void testDifferentTrees2DifferentDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 2, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + trees.tree2.get(longToken(16)).addHash(digest("diff_16"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 4, 2), + makeTreeRange(12, 16, 2))); + } + + @Test + public void testDifferentTrees3SameDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 3, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + trees.tree1.get(longToken(3)).addHash(digest("diff_3"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 4, 2))); + } + + @Test + public void testDifferentTrees3Differentepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 2, 3); + trees.tree1.get(longToken(1)).addHash(digest("diff_1"), 1); + trees.tree1.get(longToken(3)).addHash(digest("diff_3"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 4, 2))); + } + + @Test + public void testDifferentTrees4SameDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 3, 3); + trees.tree1.get(longToken(4)).addHash(digest("diff_4"), 1); + trees.tree1.get(longToken(8)).addHash(digest("diff_8"), 1); + trees.tree1.get(longToken(12)).addHash(digest("diff_12"), 1); + trees.tree1.get(longToken(16)).addHash(digest("diff_16"), 1); + testDifferences(trees, newArrayList(makeTreeRange(2, 4, 3), + makeTreeRange(6, 8, 3), + makeTreeRange(10, 12, 3), + makeTreeRange(14, 16, 3))); + } + + @Test + public void testDifferentTrees4DifferentDepth() throws IOException + { + int seed = makeSeed(); + Trees trees = Trees.make(seed, seed, 2, 3); + trees.tree1.get(longToken(4)).addHash(digest("diff_4"), 1); + trees.tree1.get(longToken(8)).addHash(digest("diff_8"), 1); + trees.tree1.get(longToken(12)).addHash(digest("diff_12"), 1); + 
trees.tree1.get(longToken(16)).addHash(digest("diff_16"), 1); + testDifferences(trees, newArrayList(makeTreeRange(0, 16, 0))); + } + + private static void testDifferences(Trees trees, List expectedDifference) throws IOException + { + MerkleTree mt1 = trees.tree1; + MerkleTree mt2 = trees.tree2; + + assertDiffer(mt1, mt2, expectedDifference); + assertDiffer(mt1, mt2.moveOffHeap(), expectedDifference); + assertDiffer(mt1, cycle(mt2, true), expectedDifference); + assertDiffer(mt1, cycle(mt2, false), expectedDifference); + assertDiffer(mt1, cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(mt1, cycle(mt2.moveOffHeap(), false), expectedDifference); + + assertDiffer(mt1.moveOffHeap(), mt2, expectedDifference); + assertDiffer(mt1.moveOffHeap(), mt2.moveOffHeap(), expectedDifference); + assertDiffer(mt1.moveOffHeap(), cycle(mt2, true), expectedDifference); + assertDiffer(mt1.moveOffHeap(), cycle(mt2, false), expectedDifference); + assertDiffer(mt1.moveOffHeap(), cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(mt1.moveOffHeap(), cycle(mt2.moveOffHeap(), false), expectedDifference); + + assertDiffer(cycle(mt1, true), mt2, expectedDifference); + assertDiffer(cycle(mt1, true), mt2.moveOffHeap(), expectedDifference); + assertDiffer(cycle(mt1, true), cycle(mt2, true), expectedDifference); + assertDiffer(cycle(mt1, true), cycle(mt2, false), expectedDifference); + assertDiffer(cycle(mt1, true), cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(cycle(mt1, true), cycle(mt2.moveOffHeap(), false), expectedDifference); + + assertDiffer(cycle(mt1, false), mt2, expectedDifference); + assertDiffer(cycle(mt1, false), mt2.moveOffHeap(), expectedDifference); + assertDiffer(cycle(mt1, false), cycle(mt2, true), expectedDifference); + assertDiffer(cycle(mt1, false), cycle(mt2, false), expectedDifference); + assertDiffer(cycle(mt1, false), cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(cycle(mt1, false), 
cycle(mt2.moveOffHeap(), false), expectedDifference); + + assertDiffer(cycle(mt1.moveOffHeap(), true), mt2, expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), true), mt2.moveOffHeap(), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), true), cycle(mt2, true), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), true), cycle(mt2, false), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), true), cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), true), cycle(mt2.moveOffHeap(), false), expectedDifference); + + assertDiffer(cycle(mt1.moveOffHeap(), false), mt2, expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), false), mt2.moveOffHeap(), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), false), cycle(mt2, true), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), false), cycle(mt2, false), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), false), cycle(mt2.moveOffHeap(), true), expectedDifference); + assertDiffer(cycle(mt1.moveOffHeap(), false), cycle(mt2.moveOffHeap(), false), expectedDifference); + } + + private static void assertDiffer(MerkleTree mt1, MerkleTree mt2, List expectedDifference) + { + assertEquals(expectedDifference, MerkleTree.difference(mt1, mt2)); + assertEquals(expectedDifference, MerkleTree.difference(mt2, mt1)); + } + + private static Range longTokenRange(long start, long end) + { + return new Range<>(longToken(start), longToken(end)); + } + + private static Murmur3Partitioner.LongToken longToken(long value) + { + return new Murmur3Partitioner.LongToken(value); + } + + private static MerkleTree cycle(MerkleTree mt, boolean offHeapRequested) throws IOException + { + try (DataOutputBuffer output = new DataOutputBuffer()) + { + mt.serialize(output, MessagingService.current_version); + + try (DataInputBuffer input = new DataInputBuffer(output.buffer(false), false)) + { + return MerkleTree.deserialize(input, offHeapRequested, 
MessagingService.current_version); + } + } + } + + private static MerkleTree makeTree(long start, long end, int depth) + { + MerkleTree mt = new MerkleTree(Murmur3Partitioner.instance, longTokenRange(start, end), depth, Long.MAX_VALUE); + mt.init(); + return mt; + } + + private static TreeRange makeTreeRange(long start, long end, int depth) + { + return new TreeRange(longToken(start), longToken(end), depth); + } + + private static byte[][] makeHashes(int count, int seed) + { + Random random = new Random(seed); + + byte[][] hashes = new byte[count][32]; + for (int i = 0; i < count; i++) + random.nextBytes(hashes[i]); + return hashes; + } + + private static int makeSeed() + { + int seed = (int) System.currentTimeMillis(); + System.out.println("Using seed " + seed); + return seed; + } + + private static class Trees + { + MerkleTree tree1; + MerkleTree tree2; + + Trees(MerkleTree tree1, MerkleTree tree2) + { + this.tree1 = tree1; + this.tree2 = tree2; + } + + static Trees make(int hashes1seed, int hashes2seed, int tree1depth, int tree2depth) + { + byte[][] hashes1 = makeHashes(16, hashes1seed); + byte[][] hashes2 = makeHashes(16, hashes2seed); + + MerkleTree tree1 = makeTree(0, 16, tree1depth); + MerkleTree tree2 = makeTree(0, 16, tree2depth); + + for (int tok = 1; tok <= 16; tok++) + { + tree1.get(longToken(tok)).addHash(hashes1[tok - 1], 1); + tree2.get(longToken(tok)).addHash(hashes2[tok - 1], 1); + } + + return new Trees(tree1, tree2); + } + } } diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java index b40f6c437c64..9e70c2048455 100644 --- a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java +++ b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java @@ -34,7 +34,6 @@ import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageService; -import org.apache.cassandra.utils.MerkleTree.Hashable; 
import org.apache.cassandra.utils.MerkleTree.RowHash; import org.apache.cassandra.utils.MerkleTree.TreeRange; import org.apache.cassandra.utils.MerkleTrees.TreeRangeIterator; @@ -43,7 +42,12 @@ public class MerkleTreesTest { - public static byte[] DUMMY = "blah".getBytes(); + private static final byte[] DUMMY = digest("dummy"); + + private static byte[] digest(String string) + { + return HashingUtils.newMessageDigest("SHA-256").digest(string.getBytes()); + } /** * If a test assumes that the tree is 8 units wide, then it should set this value @@ -193,7 +197,7 @@ public void testInvalids() Iterator ranges; // (zero, zero] - ranges = mts.invalids(); + ranges = mts.rangeIterator(); assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next()); assertFalse(ranges.hasNext()); @@ -203,7 +207,7 @@ public void testInvalids() mts.split(tok(6)); mts.split(tok(3)); mts.split(tok(5)); - ranges = mts.invalids(); + ranges = mts.rangeIterator(); assertEquals(new Range<>(tok(6), tok(-1)), ranges.next()); assertEquals(new Range<>(tok(-1), tok(2)), ranges.next()); assertEquals(new Range<>(tok(2), tok(3)), ranges.next()); @@ -245,11 +249,6 @@ public void testHashPartial() // (zero,two] (two,four] (four, zero] mts.split(tok(4)); mts.split(tok(2)); - assertNull(mts.hash(left)); - assertNull(mts.hash(partial)); - assertNull(mts.hash(right)); - assertNull(mts.hash(linvalid)); - assertNull(mts.hash(rinvalid)); // validate the range mts.get(tok(2)).hash(val); @@ -280,10 +279,6 @@ public void testHashInner() mts.split(tok(2)); mts.split(tok(6)); mts.split(tok(1)); - assertNull(mts.hash(full)); - assertNull(mts.hash(lchild)); - assertNull(mts.hash(rchild)); - assertNull(mts.hash(invalid)); // validate the range mts.get(tok(1)).hash(val); @@ -315,9 +310,6 @@ public void testHashDegenerate() mts.split(tok(4)); mts.split(tok(2)); mts.split(tok(1)); - assertNull(mts.hash(full)); - assertNull(mts.hash(childfull)); - assertNull(mts.hash(invalid)); // validate the range mts.get(tok(1)).hash(val); @@ 
-349,7 +341,7 @@ public void testHashRandom() } // validate the tree - TreeRangeIterator ranges = mts.invalids(); + TreeRangeIterator ranges = mts.rangeIterator(); for (TreeRange range : ranges) range.addHash(new RowHash(range.right, new byte[0], 0)); @@ -378,13 +370,16 @@ public void testValidateTree() mts.split(tok(6)); mts.split(tok(10)); - ranges = mts.invalids(); - ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2 - ranges.next().addAll(new HIterator(6)); // (4,6] - ranges.next().addAll(new HIterator(8)); // (6,8] + int seed = 123456789; + + Random random1 = new Random(seed); + ranges = mts.rangeIterator(); + ranges.next().addAll(new HIterator(random1, 2, 4)); // (-1,4]: depth 2 + ranges.next().addAll(new HIterator(random1, 6)); // (4,6] + ranges.next().addAll(new HIterator(random1, 8)); // (6,8] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,10] - ranges.next().addAll(new HIterator(12)); // (10,12] - ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2 + ranges.next().addAll(new HIterator(random1, 12)); // (10,12] + ranges.next().addAll(new HIterator(random1, 14, -1)); // (12,-1]: depth 2 mts2.split(tok(8)); @@ -395,15 +390,16 @@ public void testValidateTree() mts2.split(tok(9)); mts2.split(tok(11)); - ranges = mts2.invalids(); - ranges.next().addAll(new HIterator(2)); // (-1,2] - ranges.next().addAll(new HIterator(4)); // (2,4] - ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2 + Random random2 = new Random(seed); + ranges = mts2.rangeIterator(); + ranges.next().addAll(new HIterator(random2, 2)); // (-1,2] + ranges.next().addAll(new HIterator(random2, 4)); // (2,4] + ranges.next().addAll(new HIterator(random2, 6, 8)); // (4,8]: depth 2 ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,9] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (9,10] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (10,11]: depth 4 - ranges.next().addAll(new HIterator(12)); // (11,12]: 
depth 4 - ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2 + ranges.next().addAll(new HIterator(random2, 12)); // (11,12]: depth 4 + ranges.next().addAll(new HIterator(random2, 14, -1)); // (12,-1]: depth 2 byte[] mthash = mts.hash(full); byte[] mt2hash = mts2.hash(full); @@ -425,7 +421,7 @@ public void testSerialization() throws Exception // populate and validate the tree mts.init(); - for (TreeRange range : mts.invalids()) + for (TreeRange range : mts.rangeIterator()) range.addAll(new HIterator(range.right)); byte[] initialhash = mts.hash(first); @@ -456,11 +452,15 @@ public void testDifference() mts.init(); mts2.init(); + int seed = 123456789; // add dummy hashes to both trees - for (TreeRange range : mts.invalids()) - range.addAll(new HIterator(range.right)); - for (TreeRange range : mts2.invalids()) - range.addAll(new HIterator(range.right)); + Random random1 = new Random(seed); + for (TreeRange range : mts.rangeIterator()) + range.addAll(new HIterator(random1, range.right)); + + Random random2 = new Random(seed); + for (TreeRange range : mts2.rangeIterator()) + range.addAll(new HIterator(random2, range.right)); TreeRange leftmost = null; TreeRange middle = null; @@ -468,14 +468,14 @@ public void testDifference() mts.maxsize(fullRange(), maxsize + 2); // give some room for splitting // split the leftmost - Iterator ranges = mts.invalids(); + Iterator ranges = mts.rangeIterator(); leftmost = ranges.next(); mts.split(leftmost.right); // set the hashes for the leaf of the created split middle = mts.get(leftmost.right); - middle.hash("arbitrary!".getBytes()); - mts.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash("even more arbitrary!".getBytes()); + middle.hash(digest("arbitrary!")); + mts.get(partitioner.midpoint(leftmost.left, leftmost.right)).hash(digest("even more arbitrary!")); // trees should disagree for (leftmost.left, middle.right] List> diffs = MerkleTrees.difference(mts, mts2); @@ -504,7 +504,7 @@ byte[] hashed(byte[] val, 
Integer... depths) while (depth.equals(dstack.peek())) { // consume the stack - hash = Hashable.binaryHash(hstack.pop(), hash); + hash = MerkleTree.xor(hstack.pop(), hash); depth = dstack.pop()-1; } dstack.push(depth); @@ -516,25 +516,47 @@ byte[] hashed(byte[] val, Integer... depths) public static class HIterator extends AbstractIterator { - private Iterator tokens; + private final Random random; + private final Iterator tokens; - public HIterator(int... tokens) + HIterator(int... tokens) { - List tlist = new LinkedList(); + this(new Random(), tokens); + } + + HIterator(Random random, int... tokens) + { + List tlist = new ArrayList<>(tokens.length); for (int token : tokens) tlist.add(tok(token)); this.tokens = tlist.iterator(); + this.random = random; } public HIterator(Token... tokens) { - this.tokens = Arrays.asList(tokens).iterator(); + this(new Random(), tokens); + } + + HIterator(Random random, Token... tokens) + { + this(random, Arrays.asList(tokens).iterator()); + } + + private HIterator(Random random, Iterator tokens) + { + this.random = random; + this.tokens = tokens; } public RowHash computeNext() { if (tokens.hasNext()) - return new RowHash(tokens.next(), DUMMY, DUMMY.length); + { + byte[] digest = new byte[32]; + random.nextBytes(digest); + return new RowHash(tokens.next(), digest, 12345L); + } return endOfData(); } } diff --git a/test/unit/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillisTest.java b/test/unit/org/apache/cassandra/utils/MonotonicClockTest.java similarity index 87% rename from test/unit/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillisTest.java rename to test/unit/org/apache/cassandra/utils/MonotonicClockTest.java index 25aeada117fe..b2891a9950e2 100644 --- a/test/unit/org/apache/cassandra/utils/NanoTimeToCurrentTimeMillisTest.java +++ b/test/unit/org/apache/cassandra/utils/MonotonicClockTest.java @@ -17,11 +17,12 @@ */ package org.apache.cassandra.utils; +import static org.apache.cassandra.utils.MonotonicClock.approxTime; 
import static org.junit.Assert.*; import org.junit.Test; -public class NanoTimeToCurrentTimeMillisTest +public class MonotonicClockTest { @Test public void testTimestampOrdering() throws Exception @@ -34,12 +35,12 @@ public void testTimestampOrdering() throws Exception now = Math.max(now, System.currentTimeMillis()); if (ii % 10000 == 0) { - NanoTimeToCurrentTimeMillis.updateNow(); + ((MonotonicClock.SampledClock) approxTime).refreshNow(); Thread.sleep(1); } nowNanos = Math.max(nowNanos, System.nanoTime()); - long convertedNow = NanoTimeToCurrentTimeMillis.convert(nowNanos); + long convertedNow = approxTime.translate().toMillisSinceEpoch(nowNanos); int maxDiff = FBUtilities.isWindows ? 15 : 1; assertTrue("convertedNow = " + convertedNow + " lastConverted = " + lastConverted + " in iteration " + ii, diff --git a/test/unit/org/apache/cassandra/utils/memory/BufferPoolTest.java b/test/unit/org/apache/cassandra/utils/memory/BufferPoolTest.java index 74889a1fa2b8..62cb33b91ba5 100644 --- a/test/unit/org/apache/cassandra/utils/memory/BufferPoolTest.java +++ b/test/unit/org/apache/cassandra/utils/memory/BufferPoolTest.java @@ -53,7 +53,7 @@ public void setUp() @After public void cleanUp() { - BufferPool.reset(); + BufferPool.unsafeReset(); } @Test @@ -66,12 +66,12 @@ public void testGetPut() throws InterruptedException assertEquals(size, buffer.capacity()); assertEquals(true, buffer.isDirect()); - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); assertEquals(BufferPool.GlobalPool.MACRO_CHUNK_SIZE, BufferPool.sizeInBytes()); BufferPool.put(buffer); - assertEquals(null, BufferPool.currentChunk()); + assertEquals(null, BufferPool.unsafeCurrentChunk()); assertEquals(BufferPool.GlobalPool.MACRO_CHUNK_SIZE, BufferPool.sizeInBytes()); } @@ -81,7 +81,7 @@ public void testPageAligned() { final int size = 1024; for (int i = size; - i <= BufferPool.CHUNK_SIZE; + i <= BufferPool.NORMAL_CHUNK_SIZE; i 
+= size) { checkPageAligned(i); @@ -115,14 +115,14 @@ public void testDifferentSizes() throws InterruptedException assertNotNull(buffer2); assertEquals(size2, buffer2.capacity()); - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); assertEquals(BufferPool.GlobalPool.MACRO_CHUNK_SIZE, BufferPool.sizeInBytes()); BufferPool.put(buffer1); BufferPool.put(buffer2); - assertEquals(null, BufferPool.currentChunk()); + assertEquals(null, BufferPool.unsafeCurrentChunk()); assertEquals(BufferPool.GlobalPool.MACRO_CHUNK_SIZE, BufferPool.sizeInBytes()); } @@ -165,7 +165,7 @@ public void testMaxMemoryExceeded_SmallerThanChunkSize() @Test public void testRecycle() { - requestUpToSize(RandomAccessReader.DEFAULT_BUFFER_SIZE, 3 * BufferPool.CHUNK_SIZE); + requestUpToSize(RandomAccessReader.DEFAULT_BUFFER_SIZE, 3 * BufferPool.NORMAL_CHUNK_SIZE); } private void requestDoubleMaxMemory() @@ -192,13 +192,12 @@ private void requestUpToSize(int bufferSize, int totalSize) for (ByteBuffer buffer : buffers) BufferPool.put(buffer); - } @Test public void testBigRequest() { - final int size = BufferPool.CHUNK_SIZE + 1; + final int size = BufferPool.NORMAL_CHUNK_SIZE + 1; ByteBuffer buffer = BufferPool.get(size); assertNotNull(buffer); @@ -210,30 +209,30 @@ public void testBigRequest() public void testFillUpChunks() { final int size = RandomAccessReader.DEFAULT_BUFFER_SIZE; - final int numBuffers = BufferPool.CHUNK_SIZE / size; + final int numBuffers = BufferPool.NORMAL_CHUNK_SIZE / size; List buffers1 = new ArrayList<>(numBuffers); List buffers2 = new ArrayList<>(numBuffers); for (int i = 0; i < numBuffers; i++) buffers1.add(BufferPool.get(size)); - BufferPool.Chunk chunk1 = BufferPool.currentChunk(); + BufferPool.Chunk chunk1 = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk1); for (int i = 0; i < numBuffers; i++) buffers2.add(BufferPool.get(size)); - assertEquals(2, BufferPool.numChunks()); + 
assertEquals(2, BufferPool.unsafeNumChunks()); for (ByteBuffer buffer : buffers1) BufferPool.put(buffer); - assertEquals(1, BufferPool.numChunks()); + assertEquals(1, BufferPool.unsafeNumChunks()); for (ByteBuffer buffer : buffers2) BufferPool.put(buffer); - assertEquals(0, BufferPool.numChunks()); + assertEquals(0, BufferPool.unsafeNumChunks()); buffers2.clear(); } @@ -242,7 +241,7 @@ public void testFillUpChunks() public void testOutOfOrderFrees() { final int size = 4096; - final int maxFreeSlots = BufferPool.CHUNK_SIZE / size; + final int maxFreeSlots = BufferPool.NORMAL_CHUNK_SIZE / size; final int[] idxs = new int[maxFreeSlots]; for (int i = 0; i < maxFreeSlots; i++) @@ -255,7 +254,7 @@ public void testOutOfOrderFrees() public void testInOrderFrees() { final int size = 4096; - final int maxFreeSlots = BufferPool.CHUNK_SIZE / size; + final int maxFreeSlots = BufferPool.NORMAL_CHUNK_SIZE / size; final int[] idxs = new int[maxFreeSlots]; for (int i = 0; i < maxFreeSlots; i++) @@ -269,23 +268,23 @@ public void testRandomFrees() { doTestRandomFrees(12345567878L); - BufferPool.reset(); + BufferPool.unsafeReset(); doTestRandomFrees(20452249587L); - BufferPool.reset(); + BufferPool.unsafeReset(); doTestRandomFrees(82457252948L); - BufferPool.reset(); + BufferPool.unsafeReset(); doTestRandomFrees(98759284579L); - BufferPool.reset(); + BufferPool.unsafeReset(); doTestRandomFrees(19475257244L); } private void doTestRandomFrees(long seed) { final int size = 4096; - final int maxFreeSlots = BufferPool.CHUNK_SIZE / size; + final int maxFreeSlots = BufferPool.NORMAL_CHUNK_SIZE / size; final int[] idxs = new int[maxFreeSlots]; for (int i = 0; i < maxFreeSlots; i++) @@ -312,10 +311,10 @@ private void doTestFrees(final int size, final int maxFreeSlots, final int[] toR buffers.add(BufferPool.get(size)); } - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertFalse(chunk.isFree()); - int freeSize = 
BufferPool.CHUNK_SIZE - maxFreeSlots * size; + int freeSize = BufferPool.NORMAL_CHUNK_SIZE - maxFreeSlots * size; assertEquals(freeSize, chunk.free()); for (int i : toReleaseIdxs) @@ -352,33 +351,34 @@ public void testDifferentSizeBuffersOnOneChunk() assertTrue(buffer.capacity() >= sizes[i]); buffers.add(buffer); - sum += BufferPool.currentChunk().roundUp(buffer.capacity()); + sum += BufferPool.unsafeCurrentChunk().roundUp(buffer.capacity()); } // else the test will fail, adjust sizes as required assertTrue(sum <= BufferPool.GlobalPool.MACRO_CHUNK_SIZE); - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); Random rnd = new Random(); rnd.setSeed(298347529L); - while (!buffers.isEmpty()) + while (buffers.size() > 1) { int index = rnd.nextInt(buffers.size()); ByteBuffer buffer = buffers.remove(index); BufferPool.put(buffer); } + BufferPool.put(buffers.remove(0)); - assertEquals(null, BufferPool.currentChunk()); + assertEquals(null, BufferPool.unsafeCurrentChunk()); assertEquals(0, chunk.free()); } @Test public void testChunkExhausted() { - final int size = BufferPool.CHUNK_SIZE / 64; // 1kbit + final int size = BufferPool.NORMAL_CHUNK_SIZE / 64; // 1kbit int[] sizes = new int[128]; Arrays.fill(sizes, size); @@ -397,7 +397,7 @@ public void testChunkExhausted() // else the test will fail, adjust sizes as required assertTrue(sum <= BufferPool.GlobalPool.MACRO_CHUNK_SIZE); - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); for (int i = 0; i < sizes.length; i++) @@ -405,7 +405,7 @@ public void testChunkExhausted() BufferPool.put(buffers.get(i)); } - assertEquals(null, BufferPool.currentChunk()); + assertEquals(null, BufferPool.unsafeCurrentChunk()); assertEquals(0, chunk.free()); } @@ -505,9 +505,9 @@ private void checkBuffer(int size) ByteBuffer buffer = BufferPool.get(size); assertEquals(size, 
buffer.capacity()); - if (size > 0 && size < BufferPool.CHUNK_SIZE) + if (size > 0 && size < BufferPool.NORMAL_CHUNK_SIZE) { - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); assertEquals(chunk.capacity(), chunk.free() + chunk.roundUp(size)); } @@ -552,7 +552,7 @@ private void checkBufferWithGivenSlots(int size, long freeSlots) ByteBuffer buffer = BufferPool.get(size); // now get the current chunk and override the free slots mask - BufferPool.Chunk chunk = BufferPool.currentChunk(); + BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); long oldFreeSlots = chunk.setFreeSlots(freeSlots); @@ -561,7 +561,7 @@ private void checkBufferWithGivenSlots(int size, long freeSlots) assertEquals(size, buffer.capacity()); BufferPool.put(buffer2); - // reset the free slots + // unsafeReset the free slots chunk.setFreeSlots(oldFreeSlots); BufferPool.put(buffer); } @@ -587,22 +587,22 @@ public void testBufferPoolDisabled() BufferPool.DISABLED = true; BufferPool.ALLOCATE_ON_HEAP_WHEN_EXAHUSTED = true; ByteBuffer buffer = BufferPool.get(1024); - assertEquals(0, BufferPool.numChunks()); + assertEquals(0, BufferPool.unsafeNumChunks()); assertNotNull(buffer); assertEquals(1024, buffer.capacity()); assertFalse(buffer.isDirect()); assertNotNull(buffer.array()); BufferPool.put(buffer); - assertEquals(0, BufferPool.numChunks()); + assertEquals(0, BufferPool.unsafeNumChunks()); BufferPool.ALLOCATE_ON_HEAP_WHEN_EXAHUSTED = false; buffer = BufferPool.get(1024); - assertEquals(0, BufferPool.numChunks()); + assertEquals(0, BufferPool.unsafeNumChunks()); assertNotNull(buffer); assertEquals(1024, buffer.capacity()); assertTrue(buffer.isDirect()); BufferPool.put(buffer); - assertEquals(0, BufferPool.numChunks()); + assertEquals(0, BufferPool.unsafeNumChunks()); // clean-up BufferPool.DISABLED = false; @@ -794,10 +794,10 @@ private void doMultipleThreadsReleaseBuffers(final int 
threadCount, final int .. buffers[i] = BufferPool.get(sizes[i]); assertNotNull(buffers[i]); assertEquals(sizes[i], buffers[i].capacity()); - sum += BufferPool.currentChunk().roundUp(buffers[i].capacity()); + sum += BufferPool.unsafeCurrentChunk().roundUp(buffers[i].capacity()); } - final BufferPool.Chunk chunk = BufferPool.currentChunk(); + final BufferPool.Chunk chunk = BufferPool.unsafeCurrentChunk(); assertNotNull(chunk); assertFalse(chunk.isFree()); @@ -819,7 +819,7 @@ public void run() { try { - assertNotSame(chunk, BufferPool.currentChunk()); + assertNotSame(chunk, BufferPool.unsafeCurrentChunk()); BufferPool.put(buffer); } catch (AssertionError ex) @@ -849,7 +849,7 @@ public void run() System.gc(); System.gc(); - assertTrue(BufferPool.currentChunk().isFree()); + assertTrue(BufferPool.unsafeCurrentChunk().isFree()); //make sure the main thread can still allocate buffers ByteBuffer buffer = BufferPool.get(sizes[0]); diff --git a/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java b/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java index 2189be358f6c..c7c3324f7f50 100644 --- a/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java +++ b/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java @@ -22,10 +22,8 @@ import java.io.DataOutputStream; import java.io.IOException; -import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.apache.cassandra.io.util.DataOutputBuffer; -import org.apache.cassandra.net.async.ByteBufDataOutputPlus; import org.junit.Test; @@ -92,9 +90,11 @@ public void testOneByteCapacity() throws Exception { public void testByteBufWithNegativeNumber() throws IOException { int i = -1231238694; - ByteBuf buf = Unpooled.buffer(8); - VIntCoding.writeUnsignedVInt(i, new ByteBufDataOutputPlus(buf)); - long result = VIntCoding.readUnsignedVInt(buf); - Assert.assertEquals(i, result); + try (DataOutputBuffer out = new DataOutputBuffer()) + { + VIntCoding.writeUnsignedVInt(i, out); + long result = 
VIntCoding.getUnsignedVInt(out.buffer(), 0); + Assert.assertEquals(i, result); + } } } diff --git a/tools/stress/src/org/apache/cassandra/stress/generate/SeedManager.java b/tools/stress/src/org/apache/cassandra/stress/generate/SeedManager.java index 4eeb47d59100..6874fba170c0 100644 --- a/tools/stress/src/org/apache/cassandra/stress/generate/SeedManager.java +++ b/tools/stress/src/org/apache/cassandra/stress/generate/SeedManager.java @@ -228,7 +228,6 @@ void finishWrite(Seed seed) private class LookbackReadGenerator extends Generator { - final Distribution lookback; public LookbackReadGenerator(Distribution lookback) diff --git a/tools/stress/src/org/apache/cassandra/stress/settings/SettingsTransport.java b/tools/stress/src/org/apache/cassandra/stress/settings/SettingsTransport.java index 9b8eaa0efed5..4ea4cd2f4221 100644 --- a/tools/stress/src/org/apache/cassandra/stress/settings/SettingsTransport.java +++ b/tools/stress/src/org/apache/cassandra/stress/settings/SettingsTransport.java @@ -43,23 +43,26 @@ public EncryptionOptions getEncryptionOptions() EncryptionOptions encOptions = new EncryptionOptions(); if (options.trustStore.present()) { - encOptions.enabled = true; - encOptions.truststore = options.trustStore.value(); - encOptions.truststore_password = options.trustStorePw.value(); + encOptions = encOptions + .withEnabled(true) + .withTrustStore(options.trustStore.value()) + .withTrustStorePassword(options.trustStorePw.value()) + .withAlgorithm(options.alg.value()) + .withProtocol(options.protocol.value()) + .withCipherSuites(options.ciphers.value().split(",")); if (options.keyStore.present()) { - encOptions.keystore = options.keyStore.value(); - encOptions.keystore_password = options.keyStorePw.value(); + encOptions = encOptions + .withKeyStore(options.keyStore.value()) + .withKeyStorePassword(options.keyStorePw.value()); } else { // mandatory for SSLFactory.createSSLContext(), see CASSANDRA-9325 - encOptions.keystore = encOptions.truststore; - 
encOptions.keystore_password = encOptions.truststore_password; + encOptions = encOptions + .withKeyStore(encOptions.truststore) + .withKeyStorePassword(encOptions.truststore_password); } - encOptions.algorithm = options.alg.value(); - encOptions.protocol = options.protocol.value(); - encOptions.cipher_suites = options.ciphers.value().split(","); } return encOptions; } diff --git a/tools/stress/src/org/apache/cassandra/stress/util/JavaDriverClient.java b/tools/stress/src/org/apache/cassandra/stress/util/JavaDriverClient.java index fbcab4b95741..36361f7de2f7 100644 --- a/tools/stress/src/org/apache/cassandra/stress/util/JavaDriverClient.java +++ b/tools/stress/src/org/apache/cassandra/stress/util/JavaDriverClient.java @@ -148,7 +148,7 @@ public void connect(ProtocolOptions.Compression compression) throws Exception sslContext = SSLFactory.createSSLContext(encryptionOptions, true); SSLOptions sslOptions = JdkSSLOptions.builder() .withSSLContext(sslContext) - .withCipherSuites(encryptionOptions.cipher_suites).build(); + .withCipherSuites(encryptionOptions.cipher_suites.toArray(new String[0])).build(); clusterBuilder.withSSL(sslOptions); }