Skip to content

Commit

Permalink
Avoid calling ChannelManager when uninitialized
Browse files Browse the repository at this point in the history
  • Loading branch information
johnmaguire committed Feb 14, 2024
1 parent 2be01e7 commit 5dd01fc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
20 changes: 15 additions & 5 deletions cardinal/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def signedOn(self):
for channel in self.factory.channels:
self.join(channel)

# ChannelManager is only created if CHANMODES is supported
self.channels = None

# Set the uptime as now and grab the boot time from the factory
self.uptime = datetime.now()
self.booted = self.factory.booted
Expand All @@ -172,7 +175,8 @@ def joined(self, channel):
channel -- Channel joined. Provided by Twisted.
"""
self.logger.info("Joined %s" % channel)
self.channels.add(channel)
if self.channels:
self.channels.add(channel)

# Request the channel modes for this channel
self.sendLine("MODE {}".format(channel))
Expand All @@ -183,23 +187,26 @@ def irc_RPL_CHANNELMODEIS(self, prefix, params):
if modes[0] not in "-+":
modes = "+" + modes

self.channels.set_modes(channel, modes, args)
if self.channels:
self.channels.set_modes(channel, modes, args)

def left(self, channel):
"""Called when we leave a channel.
channel -- Channel joined. Provided by Twisted.
"""
self.logger.info("Parted %s" % channel)
self.channels.remove(channel)
if self.channels:
self.channels.remove(channel)

def kickedFrom(self, channel):
"""Called when we leave a channel.
channel -- Channel joined. Provided by Twisted.
"""
self.logger.info("Kicked from %s" % channel)
self.channels.remove(channel)
if self.channels:
self.channels.remove(channel)

def lineReceived(self, line):
"""Called for every line received from the server."""
Expand Down Expand Up @@ -543,7 +550,7 @@ def sendMsg(self, channel, message, length=None):
length -- Length of message. Twisted will calculate if None given.
"""
try:
if not self.channels[channel].allows_color():
if self.channels and not self.channels[channel].allows_color():
message = strip_formatting(message)
except KeyError:
pass
Expand Down Expand Up @@ -803,6 +810,9 @@ def __getitem__(self, key):
def __iter__(self):
return iter(self._channels)

def __bool__(self):
return True

def add(self, name):
self._channels[name] = Channel(name)

Expand Down
6 changes: 3 additions & 3 deletions cardinal/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ def test_signedOn_sends_server_commands(
call('MODE {} +B'.format(self.factory.nickname))
])

@patch('cardinal.bot.irc.IRCClient.sendLine')
def test_joined(self, mock_sendline):
@patch.object(CardinalBot, 'send')
def test_joined(self, mock_send):
self.cardinal.joined("#bots")
# need to request modes to track channel
mock_sendline.assert_called_once_with("MODE #bots")
mock_send.assert_called_once_with("MODE #bots")

@patch('cardinal.bot.irc.IRCClient.lineReceived')
def test_lineReceived(self, mock_parent_linereceived):
Expand Down

0 comments on commit 5dd01fc

Please sign in to comment.