From 4b02082d01c10e59134cad3a1fd14929c068fec4 Mon Sep 17 00:00:00 2001 From: Omniscimus Date: Wed, 3 Mar 2021 18:21:48 +0100 Subject: [PATCH] Reassign all variables in varMap, not only output variables --- maraboupy/MarabouNetworkONNX.py | 5 +++-- maraboupy/MarabouNetworkTF.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) mode change 100644 => 100755 maraboupy/MarabouNetworkONNX.py mode change 100644 => 100755 maraboupy/MarabouNetworkTF.py diff --git a/maraboupy/MarabouNetworkONNX.py b/maraboupy/MarabouNetworkONNX.py old mode 100644 new mode 100755 index 740b5cf638..d650b5e45b --- a/maraboupy/MarabouNetworkONNX.py +++ b/maraboupy/MarabouNetworkONNX.py @@ -829,7 +829,7 @@ def cleanShapes(self): self.shapeMap.pop(nodeName) def reassignVariable(self, var, numInVars, outVars, newOutVars): - """Reassign output variable so that output variables follow input variables + """Reassign variable so that output variables follow input variables This function computes what the given variable should be when the output variables are moved to come after the input variables. @@ -895,7 +895,8 @@ def reassignOutputVariables(self): self.upperBounds = newUpperBounds # Assign output variables to the new array - self.varMap[self.outputName] = newOutVars.reshape(self.shapeMap[self.outputName]) + for nodeName, variables in self.varMap.items(): + self.varMap[nodeName] = np.vectorize(self.reassignVariable, excluded=[1,2,3])(variables, numInVars, outVars, newOutVars) self.outputVars = self.varMap[self.outputName] def evaluateWithoutMarabou(self, inputValues): diff --git a/maraboupy/MarabouNetworkTF.py b/maraboupy/MarabouNetworkTF.py old mode 100644 new mode 100755 index 5db8b843f8..5e64926599 --- a/maraboupy/MarabouNetworkTF.py +++ b/maraboupy/MarabouNetworkTF.py @@ -879,8 +879,9 @@ def reassignOutputVariables(self): self.upperBounds = newUpperBounds # Assign output variables to the new array - self.outputVars = newOutVars.reshape(self.outputShape) - self.varMap[self.outputOp] = self.outputVars + for op, variables in self.varMap.items(): + self.varMap[op] = np.vectorize(self.reassignVariable, excluded=[1,2,3])(variables, numInVars, outVars, newOutVars) + self.outputVars = self.varMap[self.outputOp] def makeEquations(self, op): """Function to generate equations corresponding to given operation