diff --git a/python_coreml_stable_diffusion/pipeline.py b/python_coreml_stable_diffusion/pipeline.py index c7ab63de..e3d46c50 100644 --- a/python_coreml_stable_diffusion/pipeline.py +++ b/python_coreml_stable_diffusion/pipeline.py @@ -161,7 +161,7 @@ def _encode_prompt(self, # tokenize without max_length to catch any truncation untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.equal( + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])