Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce permissions for Websocket API #18719

Merged
merged 2 commits into from
Nov 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
from homeassistant.core import callback, DOMAIN as HASS_DOMAIN
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions

Expand Down Expand Up @@ -98,6 +99,9 @@ def handle_subscribe_events(hass, connection, msg):

Async friendly.
"""
if not connection.user.is_admin:
raise Unauthorized

async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
Expand Down Expand Up @@ -149,8 +153,14 @@ def handle_get_states(hass, connection, msg):

Async friendly.
"""
entity_perm = connection.user.permissions.check_entity
states = [
state for state in hass.states.async_all()
if entity_perm(state.entity_id, 'read')
]

connection.send_message(messages.result_message(
msg['id'], hass.states.async_all()))
msg['id'], states))


@decorators.async_response
Expand Down
25 changes: 20 additions & 5 deletions homeassistant/components/websocket_api/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import voluptuous as vol

from homeassistant.core import callback, Context
from homeassistant.exceptions import Unauthorized

from . import const, messages

Expand Down Expand Up @@ -63,11 +64,8 @@ def async_handle(self, msg):

try:
handler(self.hass, self, schema(msg))
except Exception: # pylint: disable=broad-except
self.logger.exception('Error handling message: %s', msg)
self.send_message(messages.error_message(
cur_id, const.ERR_UNKNOWN_ERROR,
'Unknown error.'))
except Exception as err: # pylint: disable=broad-except
self.async_handle_exception(msg, err)

self.last_id = cur_id

Expand All @@ -76,3 +74,20 @@ def async_close(self):
"""Close down connection."""
for unsub in self.event_listeners.values():
unsub()

@callback
def async_handle_exception(self, msg, err):
"""Handle an exception while processing a handler."""
if isinstance(err, Unauthorized):
code = const.ERR_UNAUTHORIZED
err_message = 'Unauthorized'
elif isinstance(err, vol.Invalid):
code = const.ERR_INVALID_FORMAT
err_message = 'Invalid format'
else:
self.logger.exception('Error handling message: %s', msg)
code = const.ERR_UNKNOWN_ERROR
err_message = 'Unknown error'

self.send_message(
messages.error_message(msg['id'], code, err_message))
11 changes: 6 additions & 5 deletions homeassistant/components/websocket_api/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
URL = '/api/websocket'
MAX_PENDING_MSG = 512

ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5
ERR_ID_REUSE = 'id_reuse'
ERR_INVALID_FORMAT = 'invalid_format'
ERR_NOT_FOUND = 'not_found'
ERR_UNKNOWN_COMMAND = 'unknown_command'
ERR_UNKNOWN_ERROR = 'unknown_error'
ERR_UNAUTHORIZED = 'unauthorized'

TYPE_RESULT = 'result'

Expand Down
6 changes: 2 additions & 4 deletions homeassistant/components/websocket_api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ async def _handle_async_response(func, hass, connection, msg):
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
connection.send_message(messages.error_message(
msg['id'], 'unknown', 'Unexpected error occurred'))
except Exception as err: # pylint: disable=broad-except
connection.async_handle_exception(msg, err)


def async_response(func):
Expand Down
5 changes: 3 additions & 2 deletions tests/components/websocket_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


@pytest.fixture
def websocket_client(hass, hass_ws_client):
def websocket_client(hass, hass_ws_client, hass_access_token):
"""Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass))
return hass.loop.run_until_complete(
hass_ws_client(hass, hass_access_token))


@pytest.fixture
Expand Down
39 changes: 39 additions & 0 deletions tests/components/websocket_api/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,42 @@ async def test_call_service_context_no_user(hass, aiohttp_client):
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None


async def test_subscribe_requires_admin(websocket_client, hass_admin_user):
"""Test subscribing events without being admin."""
hass_admin_user.groups = []
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_SUBSCRIBE_EVENTS,
'event_type': 'test_event'
})

msg = await websocket_client.receive_json()
assert not msg['success']
assert msg['error']['code'] == const.ERR_UNAUTHORIZED


async def test_states_filters_visible(hass, hass_admin_user, websocket_client):
"""Test we only get entities that we're allowed to see."""
hass_admin_user.mock_policy({
'entities': {
'entity_ids': {
'test.entity': True
}
}
})
hass.states.async_set('test.entity', 'hello')
hass.states.async_set('test.not_visible_entity', 'invisible')
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_STATES,
})

msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']

assert len(msg['result']) == 1
assert msg['result'][0]['entity_id'] == 'test.entity'