diff --git a/src/models/experimental.py b/src/models/experimental.py index 5a7ca27..1ad7d6a 100644 --- a/src/models/experimental.py +++ b/src/models/experimental.py @@ -287,12 +287,12 @@ def forward(self, x): return x -def attempt_load(weights, map_location=None): +def attempt_load(weights, map_location=None, weights_only=False): # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: # attempt_download(w) - ckpt = torch.load(w, map_location=map_location) # load + ckpt = torch.load(w, map_location=map_location, weights_only=weights_only) # load model.append( ckpt["ema" if ckpt.get("ema") else "model"].float().fuse().eval() ) # FP32 model