diff --git a/src/djcdata/interface/trainData.h b/src/djcdata/interface/trainData.h index 3ef9827..b5d81d6 100644 --- a/src/djcdata/interface/trainData.h +++ b/src/djcdata/interface/trainData.h @@ -229,7 +229,10 @@ class trainData{ void readMetaDataFromFile(const std::string& filename); std::vector getFirstRowsplits()const; - std::vector readShapesAndRowSplitsFromFile(const std::string& filename, bool checkConsistency=true); + // Returns all distinct rowsplit vectors present in this trainData (one per ragged array group). + std::vector> getAllRowsplits()const; + // Returns all distinct rowsplit vectors found in the file (one per ragged array, deduplicated). + std::vector> readShapesAndRowSplitsFromFile(const std::string& filename); void clear(); @@ -315,7 +318,8 @@ class trainData{ void checkFile(FILE *& f, const std::string& filename="")const; - void readRowSplitArray(FILE *&, std::vector &rs, bool check)const; + // Reads one typeContainer's worth of arrays from the file, collecting all distinct rowsplits found. + void readRowSplitArray(FILE *&, std::vector> &all_rs)const; std::vector > getShapes(const typeContainer& a)const; diff --git a/src/djcdata/interface/trainDataGenerator.h b/src/djcdata/interface/trainDataGenerator.h index 3a2130f..d3e55d6 100644 --- a/src/djcdata/interface/trainDataGenerator.h +++ b/src/djcdata/interface/trainDataGenerator.h @@ -127,7 +127,8 @@ class trainDataGenerator{ std::vector orig_infiles_; std::vector shuffle_indices_; std::vector > sub_shuffle_indices_; - std::vector > orig_rowsplits_; + // outer index: file index; inner vector: all distinct rowsplits for that file (one per ragged array group) + std::vector > > orig_rowsplits_; std::vector splits_; std::vector usebatch_; int randomcount_; diff --git a/src/djcdata/src/trainData.cpp b/src/djcdata/src/trainData.cpp index 0533e6d..3c34e3e 100644 --- a/src/djcdata/src/trainData.cpp +++ b/src/djcdata/src/trainData.cpp @@ -439,36 +439,39 @@ std::vector trainData::getFirstRowsplits()const{ return std::vector(); } -std::vector trainData::readShapesAndRowSplitsFromFile(const std::string& filename, bool checkConsistency){ - std::vector rowsplits; +std::vector> trainData::getAllRowsplits()const{ + std::vector> out; + const std::vector vv = {&feature_arrays_, &truth_arrays_, &weight_arrays_}; + for(const auto& a: vv){ + for(size_t i=0;isize();i++){ + const auto& rs = a->at(i).rowsplits(); + if(!rs.size()) continue; + bool found = false; + for(const auto& existing: out) + if(existing == rs){ found=true; break; } + if(!found) + out.push_back(rs); + } + } + return out; +} + +std::vector> trainData::readShapesAndRowSplitsFromFile(const std::string& filename){ + std::vector> all_rowsplits; FILE *ifile = fopen(filename.data(), "rb"); checkFile(ifile,filename); - //shapes - std::vector > dummy; readNested(feature_shapes_, ifile); readNested(truth_shapes_, ifile); readNested(weight_shapes_, ifile); - //features - readRowSplitArray(ifile,rowsplits,checkConsistency); - if(!checkConsistency && rowsplits.size()){ - fclose(ifile); - return rowsplits; - } - //truth - readRowSplitArray(ifile,rowsplits,checkConsistency); - if(!checkConsistency && rowsplits.size()){ - fclose(ifile); - return rowsplits; - } - //weights - readRowSplitArray(ifile,rowsplits,checkConsistency); + readRowSplitArray(ifile, all_rowsplits); // features + readRowSplitArray(ifile, all_rowsplits); // truth + readRowSplitArray(ifile, all_rowsplits); // weights fclose(ifile); - return rowsplits; - + return all_rowsplits; } void trainData::clear() { @@ -488,17 +491,17 @@ void trainData::checkFile(FILE *& ifile, const std::string& filename)const{ } -void trainData::readRowSplitArray(FILE *& ifile, std::vector &rowsplits, bool check)const{ +void trainData::readRowSplitArray(FILE *& ifile, std::vector> &all_rowsplits)const{ size_t size = 0; io::readFromFile(&size, ifile); for(size_t i=0;i::setBuffer: no features filled in trainData object"); - auto hasRagged = tdHasRaggedDimension(td); - auto rs = td.getFirstRowsplits(); - if(rs.size()) - orig_rowsplits_.push_back(rs); + auto allrs = td.getAllRowsplits(); + if(allrs.size()) + orig_rowsplits_.push_back(allrs); shuffle_indices_.push_back(0); std::vector vec; for(size_t i=0;i rowsplits = td.readShapesAndRowSplitsFromFile(f, firstfile);//check consistency only for first + auto allrs = td.readShapesAndRowSplitsFromFile(f); if(debuglevel>1) - std::cout << "rowsplits.size() " < allrs; for(size_t i=0;i file_combined_rs; + for(const auto& rs: file_rss){ + auto shuffled_rs = subShuffleRowSplits(rs, sub_shuffle_indices_.at(shuffled_idx)); + if(file_combined_rs.empty()){ + file_combined_rs = shuffled_rs; + } else { + auto nelems_a = simpleArrayBase::dataSplitToSplitIndices(file_combined_rs); + auto nelems_b = simpleArrayBase::dataSplitToSplitIndices(shuffled_rs); + if(nelems_a.size() != nelems_b.size()) + throw std::runtime_error("trainDataGenerator::prepareSplitting: rowsplit groups have different numbers of events in the same file"); + for(size_t j=0;j