Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/djcdata/interface/trainData.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ class trainData{
void readMetaDataFromFile(const std::string& filename);

std::vector<int64_t> getFirstRowsplits()const;
std::vector<int64_t> readShapesAndRowSplitsFromFile(const std::string& filename, bool checkConsistency=true);
// Returns all distinct rowsplit vectors present in this trainData (one per ragged array group).
std::vector<std::vector<int64_t>> getAllRowsplits()const;
// Returns all distinct rowsplit vectors found in the file (one per ragged array, deduplicated).
std::vector<std::vector<int64_t>> readShapesAndRowSplitsFromFile(const std::string& filename);

void clear();

Expand Down Expand Up @@ -315,7 +318,8 @@ class trainData{
void checkFile(FILE *& f, const std::string& filename="")const;


void readRowSplitArray(FILE *&, std::vector<int64_t> &rs, bool check)const;
// Reads one typeContainer's worth of arrays from the file, collecting all distinct rowsplits found.
void readRowSplitArray(FILE *&, std::vector<std::vector<int64_t>> &all_rs)const;

std::vector<std::vector<int> > getShapes(const typeContainer& a)const;

Expand Down
3 changes: 2 additions & 1 deletion src/djcdata/interface/trainDataGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class trainDataGenerator{
std::vector<std::string> orig_infiles_;
std::vector<size_t> shuffle_indices_;
std::vector<std::vector<size_t> > sub_shuffle_indices_;
std::vector<std::vector<int64_t> > orig_rowsplits_;
// outer index: file index; inner vector: all distinct rowsplits for that file (one per ragged array group)
std::vector<std::vector<std::vector<int64_t> > > orig_rowsplits_;
std::vector<size_t> splits_;
std::vector<bool> usebatch_;
int randomcount_;
Expand Down
55 changes: 29 additions & 26 deletions src/djcdata/src/trainData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,36 +439,39 @@ std::vector<int64_t> trainData::getFirstRowsplits()const{
return std::vector<int64_t>();
}

std::vector<int64_t> trainData::readShapesAndRowSplitsFromFile(const std::string& filename, bool checkConsistency){
std::vector<int64_t> rowsplits;
std::vector<std::vector<int64_t>> trainData::getAllRowsplits()const{
std::vector<std::vector<int64_t>> out;
const std::vector<const typeContainer* > vv = {&feature_arrays_, &truth_arrays_, &weight_arrays_};
for(const auto& a: vv){
for(size_t i=0;i<a->size();i++){
const auto& rs = a->at(i).rowsplits();
if(!rs.size()) continue;
bool found = false;
for(const auto& existing: out)
if(existing == rs){ found=true; break; }
if(!found)
out.push_back(rs);
}
}
return out;
}

std::vector<std::vector<int64_t>> trainData::readShapesAndRowSplitsFromFile(const std::string& filename){
std::vector<std::vector<int64_t>> all_rowsplits;

FILE *ifile = fopen(filename.data(), "rb");
checkFile(ifile,filename);

//shapes
std::vector<std::vector<int> > dummy;
readNested(feature_shapes_, ifile);
readNested(truth_shapes_, ifile);
readNested(weight_shapes_, ifile);

//features
readRowSplitArray(ifile,rowsplits,checkConsistency);
if(!checkConsistency && rowsplits.size()){
fclose(ifile);
return rowsplits;
}
//truth
readRowSplitArray(ifile,rowsplits,checkConsistency);
if(!checkConsistency && rowsplits.size()){
fclose(ifile);
return rowsplits;
}
//weights
readRowSplitArray(ifile,rowsplits,checkConsistency);
readRowSplitArray(ifile, all_rowsplits); // features
readRowSplitArray(ifile, all_rowsplits); // truth
readRowSplitArray(ifile, all_rowsplits); // weights

fclose(ifile);
return rowsplits;

return all_rowsplits;
}

void trainData::clear() {
Expand All @@ -488,17 +491,17 @@ void trainData::checkFile(FILE *& ifile, const std::string& filename)const{

}

void trainData::readRowSplitArray(FILE *& ifile, std::vector<int64_t> &rowsplits, bool check)const{
void trainData::readRowSplitArray(FILE *& ifile, std::vector<std::vector<int64_t>> &all_rowsplits)const{
size_t size = 0;
io::readFromFile(&size, ifile);
for(size_t i=0;i<size;i++){
auto frs = simpleArrayBase::readRowSplitsFromFileP(ifile, true);
if(frs.size()){
if(check){
if(rowsplits.size() && rowsplits != frs)
throw std::runtime_error("trainData::readShapesAndRowSplitsFromFile: row splits inconsistent");
}
rowsplits=frs;
bool found = false;
for(const auto& existing: all_rowsplits)
if(existing == frs){ found=true; break; }
if(!found)
all_rowsplits.push_back(frs);
}
}
}
Expand Down
41 changes: 29 additions & 12 deletions src/djcdata/src/trainDataGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ void trainDataGenerator::setBuffer(const trainData& td){
clear();
if(td.featureShapes().size()<1 || td.featureShapes().at(0).size()<1)
throw std::runtime_error("trainDataGenerator<T>::setBuffer: no features filled in trainData object");
auto hasRagged = tdHasRaggedDimension(td);

auto rs = td.getFirstRowsplits();
if(rs.size())
orig_rowsplits_.push_back(rs);
auto allrs = td.getAllRowsplits();
if(allrs.size())
orig_rowsplits_.push_back(allrs);
shuffle_indices_.push_back(0);
std::vector<size_t> vec;
for(size_t i=0;i<td.nElements();i++)
Expand Down Expand Up @@ -113,10 +112,11 @@ void trainDataGenerator::readInfo(){
hasRagged = tdHasRaggedDimension(td);
}
if(hasRagged){
std::vector<int64_t> rowsplits = td.readShapesAndRowSplitsFromFile(f, firstfile);//check consistency only for first
auto allrs = td.readShapesAndRowSplitsFromFile(f);
if(debuglevel>1)
std::cout << "rowsplits.size() " <<rowsplits.size() << ": "<<f << std::endl; //debuglevel
orig_rowsplits_.push_back(rowsplits);
std::cout << "rowsplits groups: " << allrs.size() << " (first size: " << (allrs.size()?allrs.at(0).size():0) << "): " << f << std::endl;
if(allrs.size())
orig_rowsplits_.push_back(allrs);
}
firstfile=false;
ntotal_ += td.nElements();
Expand Down Expand Up @@ -172,13 +172,30 @@ void trainDataGenerator::prepareSplitting(){
std::vector<int64_t> allrs;
for(size_t i=0;i<orig_rowsplits_.size();i++){
auto shuffled_idx = shuffle_indices_.at(i);
auto thisrs = orig_rowsplits_.at(shuffled_idx); //inject by file shuffle here
thisrs = subShuffleRowSplits(thisrs, sub_shuffle_indices_.at(shuffled_idx));
const auto& file_rss = orig_rowsplits_.at(shuffled_idx); //inject by file shuffle here

// Build a combined rowsplit for this file by summing sub-elements per event across all rowsplit groups.
// All groups must have the same number of events (same rowsplit length).
std::vector<int64_t> file_combined_rs;
for(const auto& rs: file_rss){
auto shuffled_rs = subShuffleRowSplits(rs, sub_shuffle_indices_.at(shuffled_idx));
if(file_combined_rs.empty()){
file_combined_rs = shuffled_rs;
} else {
auto nelems_a = simpleArrayBase::dataSplitToSplitIndices(file_combined_rs);
auto nelems_b = simpleArrayBase::dataSplitToSplitIndices(shuffled_rs);
if(nelems_a.size() != nelems_b.size())
throw std::runtime_error("trainDataGenerator::prepareSplitting: rowsplit groups have different numbers of events in the same file");
for(size_t j=0;j<nelems_a.size();j++)
nelems_a.at(j) += nelems_b.at(j);
file_combined_rs = simpleArrayBase::splitToDataSplitIndices(nelems_a);
}
}

if(i==0 || allrs.size()==0){
allrs=thisrs;}
else{
allrs = simpleArrayBase::mergeRowSplits(allrs,thisrs);
allrs = file_combined_rs;
} else {
allrs = simpleArrayBase::mergeRowSplits(allrs, file_combined_rs);
}
}

Expand Down