Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/command/sandbox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ int MainCmds::sandbox() {
if(!builder)
throw StringError("sandbox: failed to create TensorRT builder");

const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
// TensorRT 11 networks are always explicit-batch and strongly typed; createNetworkV2 takes no
// flags (the kEXPLICIT_BATCH NetworkDefinitionCreationFlag was removed).
auto network = unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0U));
if(!network)
throw StringError("sandbox: failed to create TensorRT network");

Expand Down
195 changes: 195 additions & 0 deletions cpp/neuralnet/onnxmodelbuilder.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#include "../neuralnet/onnxmodelbuilder.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <map>
#include <set>
#include <unordered_map>
#include <utility>

#include "../core/global.h"
#include "../core/test.h"
Expand Down Expand Up @@ -806,12 +812,190 @@ struct Builder {

namespace OnnxModelBuilder {

// ---- FP16 conversion (for the TensorRT 11 strongly-typed backend) ----
//
// The builder above always emits an FP32 graph. TensorRT 11 networks are strongly typed, so the
// only way to run the trunk in FP16 is to make the ONNX graph itself FP16. This pass rewrites a
// finished FP32 graph into mixed precision: every node runs FP16 except the numerically-sensitive
// nodes named in keepFP32 (RMSNorm square/reduce/sqrt reductions + the trunk tip and policy/value
// heads), which stay FP32. Graph inputs and outputs stay FP32 (KataGo feeds/reads FP32 buffers),
// so casts are inserted wherever an edge crosses an FP16<->FP32 boundary, and float weight
// initializers consumed only by FP16 nodes are converted to FP16. This reproduces the precision
// policy the old weakly-typed path expressed via setPrecision()+kOBEY_PRECISION_CONSTRAINTS.

// IEEE-754 single -> half (binary16), round-to-nearest-even, with inf/nan/overflow/subnormal handling.
//
// Returns the 16-bit binary16 pattern for f:
//   - Inf stays Inf (0x7c00 plus sign); any NaN becomes the quiet NaN 0x7e00 plus sign.
//   - A *finite* value whose rounded magnitude exceeds the largest finite half is clamped to the max
//     finite half (+-65504) rather than promoted to Inf. This covers both inputs whose exponent is
//     already out of range and inputs in [65520, 65536) that would otherwise round up into Inf.
//   - Values too small for a half subnormal flush to signed zero; float subnormals land here too.
static uint16_t floatToHalf(float f) {
  const uint32_t halfInf = 0x7c00u;        // exponent all ones, mantissa zero
  const uint32_t halfMaxFinite = 0x7bffu;  // 65504

  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  const uint32_t sign = (x >> 16) & 0x8000u;
  const uint32_t mant = x & 0x007fffffu;
  const int32_t rawExp = (int32_t)((x >> 23) & 0xffu);
  if(rawExp == 0xff) // Inf / NaN preserved as-is
    return (uint16_t)(sign | (mant != 0 ? 0x7e00u : halfInf));
  const int32_t exp = rawExp - 127 + 15; // rebias to half
  if(exp >= 0x1f)
    // A *finite* value too large for half is clamped to the max finite half (+-65504) rather than
    // promoted to Inf. NN weights are never meant to be infinite, and KataGo uses large sentinel
    // constants (e.g. the 1e9 attention off-board mask bias) that must stay finite in FP16 - an Inf
    // there yields 0*Inf = NaN in the attention softmax. Clamping preserves the intended semantics
    // (a huge-but-finite negative bias still drives softmax to ~0).
    return (uint16_t)(sign | halfMaxFinite);
  if(exp <= 0) { // subnormal half or zero
    if(exp < -10)
      return (uint16_t)sign; // too small even for a subnormal
    const uint32_t m = mant | 0x00800000u; // restore implicit leading 1
    const int shift = 14 - exp;            // in [14, 24]
    uint32_t half = m >> shift;
    const uint32_t rem = m & ((1u << shift) - 1u);
    const uint32_t halfway = 1u << (shift - 1);
    if(rem > halfway || (rem == halfway && (half & 1u)))
      half += 1; // round to nearest even (may carry up to the smallest normal, which is correct)
    return (uint16_t)(sign | half);
  }
  // normalized: assemble the magnitude first so the rounding carry can be range-checked.
  uint32_t half = ((uint32_t)exp << 10) | (mant >> 13);
  const uint32_t rem = mant & 0x1fffu;
  // round to nearest even; a mantissa carry naturally propagates into the exponent field.
  if(rem > 0x1000u || (rem == 0x1000u && (half & 1u)))
    half += 1;
  // With exp == 30 and an all-ones mantissa the carry lands exactly on the Inf pattern. The input was
  // finite, so apply the same clamp as the out-of-range exponent case above.
  if(half >= halfInf)
    half = halfMaxFinite;
  return (uint16_t)(sign | half);
}

// Convert a FLOAT initializer to FLOAT16 in place.
//
// The FP32 payload is read from whichever field carries it: raw_data (little-endian float32 bytes)
// when that field is non-empty, otherwise the typed float_data field. Each element goes through
// floatToHalf and the result is stored as little-endian FLOAT16 raw_data; float_data is cleared and
// data_type is set to FLOAT16.
//
// Throws StringError if raw_data is present but its length is not a multiple of 4 bytes, rather than
// silently emitting a truncated tensor.
static void convertInitializerToFP16(onnx::TensorProto* init) {
  std::string packed;
  auto appendHalf = [&packed](float value) {
    const uint16_t h = floatToHalf(value);
    packed.push_back((char)(h & 0xffu)); // little-endian
    packed.push_back((char)((h >> 8) & 0xffu));
  };

  const std::string& srcRaw = init->raw_data();
  if(!srcRaw.empty()) {
    // Payload stored as raw bytes. Without this branch the tensor would be relabelled FLOAT16 while
    // its data was replaced by an empty string, because float_data is empty in that layout.
    if(srcRaw.size() % 4 != 0)
      throw StringError("OnnxModelBuilder: FLOAT initializer '" + init->name() + "' has raw_data whose size is not a multiple of 4");
    const size_t count = srcRaw.size() / 4;
    packed.reserve(count * 2);
    for(size_t i = 0; i < count; i++) {
      // Assemble the little-endian float32 bit pattern byte by byte so host endianness is irrelevant.
      uint32_t bits = 0;
      for(int b = 3; b >= 0; b--)
        bits = (bits << 8) | (uint32_t)(unsigned char)srcRaw[4 * i + (size_t)b];
      float value;
      std::memcpy(&value, &bits, sizeof(value));
      appendHalf(value);
    }
  }
  else {
    const int n = init->float_data_size();
    packed.reserve((size_t)n * 2);
    for(int i = 0; i < n; i++)
      appendHalf(init->float_data(i));
  }

  // srcRaw aliases the field being overwritten below; all reads from it are finished by this point.
  init->clear_float_data();
  init->set_data_type(onnx::TensorProto::FLOAT16);
  init->set_raw_data(packed);
}

// Convert an all-FP32 graph to mixed FP16/FP32 precision, in place.
//
//   graph    - the graph to rewrite; its initializers and node list are modified.
//   keepFP32 - *names* of the nodes that must keep running in FP32; every other node runs in FP16.
//
// Steps:
//   1. Float initializers read exclusively by FP16 nodes are converted to FP16.
//   2. The node list is rebuilt, with a Cast node emitted immediately ahead of the first consumer
//      of any float edge whose two ends disagree on precision (so topological order is preserved).
//      One Cast is shared per (source tensor, target precision).
//   3. Before the new node list is installed, every graph output with a known producer is checked to
//      still be FP32; a StringError is thrown otherwise.
static void convertGraphToFloat16(onnx::GraphProto* graph, const std::set<std::string>& keepFP32) {
  using std::string;

  // A node runs in FP16 unless its name appears in keepFP32.
  const auto runsInHalf = [&keepFP32](const onnx::NodeProto& node) {
    return keepFP32.find(node.name()) == keepFP32.end();
  };

  std::set<string> inputNames;
  for(int i = 0; i < graph->input_size(); i++)
    inputNames.insert(graph->input(i).name());

  // Output tensor name -> index of the node producing it (outputs are expected to be uniquely named).
  std::unordered_map<string, int> producerIdx;
  for(int nodeIdx = 0; nodeIdx < graph->node_size(); nodeIdx++) {
    const onnx::NodeProto& node = graph->node(nodeIdx);
    for(int k = 0; k < node.output_size(); k++)
      producerIdx[node.output(k)] = nodeIdx;
  }

  // Index the initializers. Only FLOAT ones are candidates for conversion; others (e.g. INT64
  // axes/shapes) are never touched.
  std::unordered_map<string, onnx::TensorProto*> initializers;
  std::set<string> floatInits;
  for(onnx::TensorProto& tensor : *graph->mutable_initializer()) {
    initializers[tensor.name()] = &tensor;
    if(tensor.data_type() == onnx::TensorProto::FLOAT)
      floatInits.insert(tensor.name());
  }

  // Record, per float initializer, which precisions its consumers run in.
  std::set<string> usedByHalf;
  std::set<string> usedByFloat;
  for(const onnx::NodeProto& node : graph->node()) {
    std::set<string>& bucket = runsInHalf(node) ? usedByHalf : usedByFloat;
    for(const string& name : node.input()) {
      if(floatInits.count(name) > 0)
        bucket.insert(name);
    }
  }

  // Convert an initializer only when *all* of its consumers are FP16. A shared initializer stays
  // FP32 and its FP16 consumers read it through a Cast instead.
  std::set<string> halfInits;
  for(const string& name : usedByHalf) {
    if(usedByFloat.count(name) > 0)
      continue;
    convertInitializerToFP16(initializers[name]);
    halfInits.insert(name);
  }

  // Whether a tensor carries floating-point data at all (non-float tensors never get a Cast).
  const auto holdsFloat = [&](const string& name) -> bool {
    if(inputNames.count(name) > 0)
      return true; // all KataGo graph inputs are FLOAT
    if(initializers.count(name) > 0)
      return floatInits.count(name) > 0;
    return producerIdx.count(name) > 0; // node outputs in this graph are all float
  };

  // Whether a tensor is FP16 after conversion. Graph inputs are always FP32.
  const auto isHalf = [&](const string& name) -> bool {
    if(inputNames.count(name) > 0)
      return false;
    if(initializers.count(name) > 0)
      return halfInits.count(name) > 0;
    const auto found = producerIdx.find(name);
    if(found == producerIdx.end())
      return false;
    return runsInHalf(graph->node(found->second));
  };

  std::map<std::pair<string, bool>, string> castOutputs; // (source tensor, wantHalf) -> cast output
  int numCasts = 0;
  google::protobuf::RepeatedPtrField<onnx::NodeProto> rebuilt;

  // Returns the tensor name a consumer of the given precision should read for `name`. When a
  // precision change is needed and no matching Cast exists yet, one is appended to `rebuilt`.
  const auto routeInput = [&](const string& name, bool wantHalf) -> string {
    const bool needsCast = !name.empty() && holdsFloat(name) && isHalf(name) != wantHalf;
    if(!needsCast)
      return name;

    const std::pair<string, bool> key(name, wantHalf);
    const auto cached = castOutputs.find(key);
    if(cached != castOutputs.end())
      return cached->second;

    const string castOut = name + (wantHalf ? "__tofp16_" : "__tofp32_") + Global::intToString(numCasts++);
    onnx::NodeProto* cast = rebuilt.Add();
    cast->set_op_type("Cast");
    cast->set_name(castOut + "/cast");
    cast->add_input(name);
    cast->add_output(castOut);
    onnx::AttributeProto* target = cast->add_attribute();
    target->set_name("to");
    target->set_type(onnx::AttributeProto::INT);
    target->set_i(wantHalf ? onnx::TensorProto::FLOAT16 : onnx::TensorProto::FLOAT);
    castOutputs[key] = castOut;
    return castOut;
  };

  for(const onnx::NodeProto& original : graph->node()) {
    const bool wantHalf = runsInHalf(original);

    // Resolve every input first so any new Cast nodes land ahead of this node in `rebuilt`.
    std::vector<string> routed;
    routed.reserve((size_t)original.input_size());
    for(const string& name : original.input())
      routed.push_back(routeInput(name, wantHalf));

    onnx::NodeProto* copy = rebuilt.Add();
    copy->CopyFrom(original);
    for(int k = 0; k < copy->input_size(); k++)
      copy->set_input(k, routed[(size_t)k]);
  }

  // Sanity (checked against the original node list, before the swap below): graph outputs must remain
  // FP32, since their producers are in keepFP32 and getOutput does a flat FP32 cudaMemcpy of each
  // output binding. Fail loudly rather than silently producing garbage if that ever stops holding.
  for(const auto& out : graph->output()) {
    const auto found = producerIdx.find(out.name());
    if(found == producerIdx.end())
      continue;
    if(runsInHalf(graph->node(found->second)))
      throw StringError("OnnxModelBuilder: FP16 conversion left graph output '" + out.name() + "' in FP16");
  }

  graph->mutable_node()->Swap(&rebuilt);
}

Result build(
const ModelDesc& desc,
int nnXLen,
int nnYLen,
bool requireExactNNLen,
bool transformerNHWC,
bool useFP16,
Logger* logger
) {
if(desc.metaEncoderVersion > 0)
Expand Down Expand Up @@ -1051,6 +1235,17 @@ Result build(

b.recordNodesSince(trunkTipAndHeadStart, b.trunkTipAndHeadNodeNames);

// For TensorRT 11 strongly-typed engines: rewrite the finished FP32 graph into mixed FP16/FP32,
// keeping the RMSNorm reductions and the trunk-tip + heads in FP32 (the same regions the old
// weakly-typed path pinned via setPrecision). Inputs/outputs stay FP32.
if(useFP16) {
std::set<string> keepFP32(b.trunkTipAndHeadNodeNames.begin(), b.trunkTipAndHeadNodeNames.end());
keepFP32.insert(b.rmsNormNodeNames.begin(), b.rmsNormNodeNames.end());
convertGraphToFloat16(graph, keepFP32);
if(logger != NULL)
logger->write("OnnxModelBuilder: converted trunk to FP16 (" + Global::intToString((int)keepFP32.size()) + " nodes kept FP32)");
}

// DEBUG (kept commented out): expose every internal node output as an extra FP32 graph output so the
// backend can dump per-layer activations for FP16-vs-FP32 *numerical* divergence analysis. This is
// complementary to the trtDumpDebugPlanToDir engine dump (which shows fusion structure and boundary
Expand Down
6 changes: 5 additions & 1 deletion cpp/neuralnet/onnxmodelbuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ namespace OnnxModelBuilder {
std::vector<std::string> rmsNormNodeNames; // every RMSNorm (transformer + trunk-tip) op
};

// Build a serialized ONNX ModelProto for the given model.
// Build a serialized ONNX ModelProto for the given model. When useFP16 is set, the finished graph
// is rewritten to run the trunk in FP16 with the numerically-sensitive regions (RMSNorm reductions,
// trunk tip, policy/value heads) and the graph inputs/outputs kept in FP32 (see convertGraphToFloat16
// in the .cpp). This is what makes FP16 possible under TensorRT 11's strongly-typed networks.
Result build(
const ModelDesc& desc,
int nnXLen,
int nnYLen,
bool requireExactNNLen,
bool transformerNHWC,
bool useFP16,
Logger* logger
);
}
Expand Down
Loading