Skip to content

Commit

Permalink
Better handling of malformed messages (#69)
Browse files Browse the repository at this point in the history
* Refactor the Payload for a small performance boost, and better handling of when the message is malformed. Fixes #67

* Testing if ameba passes without using alpine Crystal

* When converting the identifier back to JSON, use the original JSON instead of generating a new one to ensure the result is always the same
  • Loading branch information
jwoertink authored May 1, 2023
1 parent fb45456 commit b05e0f4
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- crystal_version: nightly
experimental: true
runs-on: ubuntu-latest
container: crystallang/crystal:${{ matrix.crystal_version }}-alpine
container: crystallang/crystal:${{ matrix.crystal_version }}
continue-on-error: ${{ matrix.experimental }}
services:
redis:
Expand Down
31 changes: 31 additions & 0 deletions spec/cable/connection_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,37 @@ describe Cable::Connection do
ConnectionTest::CHANNELS.keys.size.should eq(0)
end

describe "#receive" do
it "ignores empty messages" do
connect do |connection, socket|
connection.receive("")
sleep 0.1

socket.messages.size.should eq(0)

connection.close
socket.close
end
end

it "ignores incorrect json structures" do
connect do |connection, socket|
# The handler handles exception catching
# so we just make sure the correct exception is thrown
expect_raises(JSON::SerializableError) do
connection.receive([{command: "subscribe"}].to_json)
end

sleep 0.1

socket.messages.size.should eq(0)

connection.close
socket.close
end
end
end

describe "#subscribe" do
it "accepts" do
connect do |connection, socket|
Expand Down
2 changes: 1 addition & 1 deletion spec/cable/handler_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ describe Cable::Handler do

FakeExceptionService.size.should eq(1)
FakeExceptionService.exceptions.first.keys.first.should eq("Cable::Handler#socket.on_message")
FakeExceptionService.exceptions.first.values.first.class.should eq(KeyError)
FakeExceptionService.exceptions.first.values.first.class.should eq(JSON::SerializableError)
end

it "rejected" do
Expand Down
8 changes: 4 additions & 4 deletions spec/cable/payload_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ describe Cable::Payload do
}.to_json,
}.to_json

payload = Cable::Payload.new(payload_json)
payload = Cable::Payload.from_json(payload_json)
payload.command.should eq("subscribe")
payload.identifier.should eq({channel: "ChatChannel", person: {name: "Foo", age: 32, boom: "boom"}, foo: "bar"}.to_json)
payload.identifier.key.should eq({channel: "ChatChannel", person: {name: "Foo", age: 32, boom: "boom"}, foo: "bar"}.to_json)
payload.channel.should eq("ChatChannel")
payload.channel_params.should eq({"person" => {"name" => "Foo", "age" => 32, "boom" => "boom"}, "foo" => "bar"})
end
Expand All @@ -27,9 +27,9 @@ describe Cable::Payload do
data: {invite_id: 3, action: "invite"}.to_json,
}.to_json

payload = Cable::Payload.new(payload_json)
payload = Cable::Payload.from_json(payload_json)
payload.command.should eq("message")
payload.identifier.should eq({channel: "ChatChannel"}.to_json)
payload.identifier.key.should eq({channel: "ChatChannel"}.to_json)
payload.channel.should eq("ChatChannel")
payload.data.should eq({"invite_id" => 3})
payload.action?.should be_truthy
Expand Down
26 changes: 15 additions & 11 deletions src/cable/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ module Cable
raise UnathorizedConnectionException.new
end

def receive(message)
payload = Cable::Payload.new(message)
# Convert the `message` to a proper `Payload`.
# The `Cable::Handler` will handle catching `SerializableError`,
# and close the socket and connection
def receive(message : String)
return unless message.presence
payload = Cable::Payload.from_json(message)

return subscribe(payload) if payload.command == "subscribe"
return unsubscribe(payload) if payload.command == "unsubscribe"
Expand All @@ -84,11 +88,11 @@ module Cable

channel = Cable::Channel::CHANNELS[payload.channel].new(
connection: self,
identifier: payload.identifier,
identifier: payload.identifier.key,
params: payload.channel_params
)
Connection::CHANNELS[connection_identifier] ||= {} of String => Cable::Channel
Connection::CHANNELS[connection_identifier][payload.identifier] = channel
Connection::CHANNELS[connection_identifier][payload.identifier.key] = channel
channel.subscribed

if channel.subscription_rejected?
Expand All @@ -102,36 +106,36 @@ module Cable
end

Cable::Logger.info { "#{payload.channel} is transmitting the subscription confirmation" }
socket.send({type: Cable.message(:confirmation), identifier: payload.identifier}.to_json)
socket.send({type: Cable.message(:confirmation), identifier: payload.identifier.key}.to_json)

channel.run_after_subscribed_callbacks unless channel.subscription_rejected?
end

# ensure we only allow subscribing to the same channel once from a connection
def connection_requesting_duplicate_channel_subscription?(payload)
return unless connection_key = Connection::CHANNELS.dig?(connection_identifier, payload.identifier)
return unless connection_key = Connection::CHANNELS.dig?(connection_identifier, payload.identifier.key)

connection_key.class.to_s == payload.channel
end

def unsubscribe(payload : Cable::Payload)
if channel = Connection::CHANNELS[connection_identifier].delete(payload.identifier)
if channel = Connection::CHANNELS[connection_identifier].delete(payload.identifier.key)
channel.close
Cable::Logger.info { "#{payload.channel} is transmitting the unsubscribe confirmation" }
socket.send({type: Cable.message(:unsubscribe), identifier: payload.identifier}.to_json)
socket.send({type: Cable.message(:unsubscribe), identifier: payload.identifier.key}.to_json)
end
end

def reject(payload : Cable::Payload)
if channel = Connection::CHANNELS[connection_identifier].delete(payload.identifier)
if channel = Connection::CHANNELS[connection_identifier].delete(payload.identifier.key)
channel.unsubscribed
Cable::Logger.info { "#{channel.class} is transmitting the subscription rejection" }
socket.send({type: Cable.message(:rejection), identifier: payload.identifier}.to_json)
socket.send({type: Cable.message(:rejection), identifier: payload.identifier.key}.to_json)
end
end

def message(payload : Cable::Payload)
if channel = Connection::CHANNELS.dig?(connection_identifier, payload.identifier)
if channel = Connection::CHANNELS.dig?(connection_identifier, payload.identifier.key)
if payload.action?
Cable::Logger.info { "#{channel.class}#perform(\"#{payload.action}\", #{payload.data})" }
channel.perform(payload.action, payload.data)
Expand Down
2 changes: 1 addition & 1 deletion src/cable/handler.cr
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module Cable
socket.on_message do |message|
begin
connection.receive(message)
rescue e : KeyError | JSON::ParseException
rescue e : KeyError | JSON::ParseException | JSON::SerializableError
# handle unknown/malformed messages
socket.close(HTTP::WebSocket::CloseCode::InvalidFramePayloadData, "Invalid message")
Cable.server.remove_connection(connection_id)
Expand Down
130 changes: 83 additions & 47 deletions src/cable/payload.cr
Original file line number Diff line number Diff line change
@@ -1,82 +1,118 @@
module Cable
class Payload
struct Payload
include JSON::Serializable
include JSON::Serializable::Unmapped
alias RESULT = String | Int64 | Hash(String, RESULT)
alias PARAMS = Hash(String, RESULT)

getter json : String
getter action : String
getter command : String?
getter identifier : String
getter channel : String?
getter channel_params : Hash(String, Cable::Payload::RESULT) = Hash(String, RESULT).new
getter data : Hash(String, Cable::Payload::RESULT) = Hash(String, RESULT).new

def initialize(@json : String)
@parsed_json = JSON.parse(@json)
@action = ""
@is_action = false
@command = @parsed_json["command"].as_s
@identifier = @parsed_json["identifier"].as_s
@channel = process_channel
@channel_params = process_channel_params
@data = process_data
module IdentifierConverter
def self.from_json(value : JSON::PullParser) : Indentifier
key = value.read_string
i = Indentifier.from_json(key)
i.key = key
i
end

def self.to_json(value : Indentifier, json : JSON::Builder) : Nil
json.string(value.key)
end
end

struct Indentifier
include JSON::Serializable
include JSON::Serializable::Unmapped

property channel : String

# This is the original JSON used to parse this
# It's used as a unique key to map the different channels
@[JSON::Field(ignore: true)]
property key : String = ""
end

def action?
@is_action
@[JSON::Field]
getter command : String

@[JSON::Field(converter: Cable::Payload::IdentifierConverter)]
getter identifier : Indentifier

@[JSON::Field(ignore: true)]
getter action : String = ""

# After the Payload is deserialized, parse the data.
# This will ensure we know if it's an action.
def after_initialize
data
end

private def parsed_identifier
JSON.parse(@parsed_json["identifier"].as_s)
def channel : String
identifier.channel
end

private def process_channel
parsed_identifier["channel"].as_s
def action? : Bool
!action.presence.nil?
end

private def json_data
JSON.parse(@parsed_json["data"].as_s) if @parsed_json.as_h.has_key?("data")
@[JSON::Field(ignore: true)]
@_channel_params : Hash(String, RESULT)? = nil

# These are the additional data sent with the identifier
# e.g. `{channel: "RoomChannel", room_id: 1}`
# ```
# channel_params["room_id"] # => 1
# ```
def channel_params : Hash(String, RESULT)
if @_channel_params.nil?
@_channel_params = process_hash(identifier.json_unmapped)
else
@_channel_params.as(Hash(String, RESULT))
end
end

private def process_data
if jsd = json_data
params = jsd.as_h.dup
if deleted_action = params.delete("action")
@action = deleted_action.as_s
@is_action = true
end
@[JSON::Field(ignore: true)]
@_data : Hash(String, RESULT)? = nil

process_hash(params)
def data : Hash(String, RESULT)
if @_data.nil?
if unmapped_data = json_unmapped["data"]?
@_data = process_data(unmapped_data.as_s)
else
@_data = no_data
end
else
Hash(String, RESULT).new
@_data.as(Hash(String, RESULT))
end
end

private def process_channel_params
params = parsed_identifier.as_h.dup
params.delete("channel")
private def no_data : Hash(String, RESULT)
Hash(String, RESULT).new
end

process_hash(params)
private def process_hash(_params : Nil)
no_data
end

private def process_hash(params : Hash(String, JSON::Any))
params_result = Hash(String, RESULT).new

params.each do |k, v|
if v.as_s?
params_result[k] = v.as_s
elsif v.as_i64?
params_result[k] = v.as_i64
elsif v.as_h?
params_result[k] = process_hash(v)
if strval = v.as_s?
params_result[k] = strval
elsif intval = v.as_i64?
params_result[k] = intval
elsif hshval = v.as_h?
params_result[k] = process_hash(hshval)
end
end

params_result
end

private def process_hash(hash : JSON::Any)
process_hash(hash.as_h)
private def process_data(data_string : String)
json_data = JSON.parse(data_string).as_h?
hash = process_hash(json_data)
@action = hash.delete("action").to_s
hash
end
end
end

0 comments on commit b05e0f4

Please sign in to comment.