diff --git a/hackable_diffusion/lib/corruption/discrete.py b/hackable_diffusion/lib/corruption/discrete.py index a1cb4d2..bfe6c21 100644 --- a/hackable_diffusion/lib/corruption/discrete.py +++ b/hackable_diffusion/lib/corruption/discrete.py @@ -27,6 +27,7 @@ import jax import jax.numpy as jnp import kauldron.ktyping as kt +import numpy as np ################################################################################ # MARK: Constants @@ -202,10 +203,10 @@ def is_masking(self) -> bool: return False else: invariant_probs_masking = (0.0,) * self.num_categories + (1.0,) - invariant_probs_masking_vec = jnp.array(invariant_probs_masking) - return jnp.all( - self.invariant_probs_vec == invariant_probs_masking_vec - ).item() + invariant_probs_masking_vec = np.array(invariant_probs_masking) + return np.all( + np.array(self.invariant_probs) == invariant_probs_masking_vec + ) ############################################################################## # MARK: Methods