diff --git a/qadence/noise/protocols.py b/qadence/noise/protocols.py index af74e01e..fdb0f46d 100644 --- a/qadence/noise/protocols.py +++ b/qadence/noise/protocols.py @@ -149,16 +149,22 @@ def list(cls) -> list: return list(filter(lambda el: not el.startswith("__"), dir(cls))) def filter(self, protocol: NoiseEnum | str) -> NoiseHandler | None: - is_protocol: list = list() if protocol == NoiseProtocol.READOUT: - is_protocol = [p == protocol for p in self.protocol] + + def filter_fn(p: NoiseEnum | str) -> bool: + return p == protocol + else: - is_protocol = [isinstance(p, protocol) for p in self.protocol] # type: ignore[arg-type] + + def filter_fn(p: NoiseEnum | str) -> bool: + return isinstance(p, protocol) # type: ignore[arg-type] + + protocol_matches: list = list(filter(filter_fn, self.protocol)) # if we have at least a match - if sum(is_protocol) > 0: + if True in protocol_matches: return NoiseHandler( - list(compress(self.protocol, is_protocol)), - list(compress(self.options, is_protocol)), + list(compress(self.protocol, protocol_matches)), + list(compress(self.options, protocol_matches)), ) return None