forked from home-assistant/core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Snips ASR and NLU component (home-assistant#8156)
* Snips ASR and NLU component * Fix warning * Fix warnings * Fix lint issues * Add tests * Fix tabs * Fix newline * Fix quotes * Fix docstrings * Update tests * Remove logs * Fix lint warning * Update API * Fix Snips
- Loading branch information
1 parent
fe20e8e
commit ea11f20
Showing
2 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
""" | ||
Support for Snips on-device ASR and NLU. | ||
For more details about this component, please refer to the documentation at | ||
https://home-assistant.io/components/snips/ | ||
""" | ||
import asyncio | ||
import copy | ||
import json | ||
import logging | ||
import voluptuous as vol | ||
from homeassistant.helpers import template, script, config_validation as cv | ||
import homeassistant.loader as loader | ||
|
||
DOMAIN = 'snips' | ||
DEPENDENCIES = ['mqtt'] | ||
CONF_INTENTS = 'intents' | ||
CONF_ACTION = 'action' | ||
|
||
INTENT_TOPIC = 'hermes/nlu/intentParsed' | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
CONFIG_SCHEMA = vol.Schema({ | ||
DOMAIN: { | ||
CONF_INTENTS: { | ||
cv.string: { | ||
vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA, | ||
} | ||
} | ||
} | ||
}, extra=vol.ALLOW_EXTRA) | ||
|
||
INTENT_SCHEMA = vol.Schema({ | ||
vol.Required('text'): str, | ||
vol.Required('intent'): { | ||
vol.Required('intent_name'): str | ||
}, | ||
vol.Optional('slots'): [{ | ||
vol.Required('slot_name'): str, | ||
vol.Required('value'): { | ||
vol.Required('kind'): str, | ||
vol.Required('value'): cv.match_all | ||
} | ||
}] | ||
}, extra=vol.ALLOW_EXTRA) | ||
|
||
|
||
@asyncio.coroutine | ||
def async_setup(hass, config): | ||
"""Activate Snips component.""" | ||
mqtt = loader.get_component('mqtt') | ||
intents = config[DOMAIN].get(CONF_INTENTS, {}) | ||
handler = IntentHandler(hass, intents) | ||
|
||
@asyncio.coroutine | ||
def message_received(topic, payload, qos): | ||
"""Handle new messages on MQTT.""" | ||
LOGGER.debug("New intent: %s", payload) | ||
yield from handler.handle_intent(payload) | ||
|
||
yield from mqtt.async_subscribe(hass, INTENT_TOPIC, message_received) | ||
|
||
return True | ||
|
||
|
||
class IntentHandler(object): | ||
"""Help handling intents.""" | ||
|
||
def __init__(self, hass, intents): | ||
"""Initialize the intent handler.""" | ||
self.hass = hass | ||
intents = copy.deepcopy(intents) | ||
template.attach(hass, intents) | ||
|
||
for name, intent in intents.items(): | ||
if CONF_ACTION in intent: | ||
intent[CONF_ACTION] = script.Script( | ||
hass, intent[CONF_ACTION], "Snips intent {}".format(name)) | ||
|
||
self.intents = intents | ||
|
||
@asyncio.coroutine | ||
def handle_intent(self, payload): | ||
"""Handle an intent.""" | ||
try: | ||
response = json.loads(payload) | ||
except TypeError: | ||
LOGGER.error('Received invalid JSON: %s', payload) | ||
return | ||
|
||
try: | ||
response = INTENT_SCHEMA(response) | ||
except vol.Invalid as err: | ||
LOGGER.error('Intent has invalid schema: %s. %s', err, response) | ||
return | ||
|
||
intent = response['intent']['intent_name'].split('__')[-1] | ||
config = self.intents.get(intent) | ||
|
||
if config is None: | ||
LOGGER.warning("Received unknown intent %s. %s", intent, response) | ||
return | ||
|
||
action = config.get(CONF_ACTION) | ||
|
||
if action is not None: | ||
slots = self.parse_slots(response) | ||
yield from action.async_run(slots) | ||
|
||
def parse_slots(self, response): | ||
"""Parse the intent slots.""" | ||
parameters = {} | ||
|
||
for slot in response.get('slots', []): | ||
key = slot['slot_name'] | ||
value = self.get_value(slot['value']) | ||
if value is not None: | ||
parameters[key] = value | ||
|
||
return parameters | ||
|
||
@staticmethod | ||
def get_value(value): | ||
"""Return the value of a given slot.""" | ||
kind = value['kind'] | ||
|
||
if kind == "Custom": | ||
return value["value"] | ||
elif kind == "Builtin": | ||
try: | ||
return value["value"]["value"] | ||
except KeyError: | ||
return None | ||
else: | ||
LOGGER.warning('Received unknown slot type: %s', kind) | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""Test the Snips component.""" | ||
import asyncio | ||
|
||
from homeassistant.bootstrap import async_setup_component | ||
from tests.common import async_fire_mqtt_message, async_mock_service | ||
|
||
EXAMPLE_MSG = """ | ||
{ | ||
"text": "turn the lights green", | ||
"intent": { | ||
"intent_name": "Lights", | ||
"probability": 1 | ||
}, | ||
"slots": [ | ||
{ | ||
"slot_name": "light_color", | ||
"value": { | ||
"kind": "Custom", | ||
"value": "blue" | ||
} | ||
} | ||
] | ||
} | ||
""" | ||
|
||
|
||
@asyncio.coroutine | ||
def test_snips_call_action(hass, mqtt_mock): | ||
"""Test calling action via Snips.""" | ||
calls = async_mock_service(hass, 'test', 'service') | ||
|
||
result = yield from async_setup_component(hass, "snips", { | ||
"snips": { | ||
"intents": { | ||
"Lights": { | ||
"action": { | ||
"service": "test.service", | ||
"data_template": { | ||
"color": "{{ light_color }}" | ||
} | ||
} | ||
} | ||
} | ||
} | ||
}) | ||
assert result | ||
|
||
async_fire_mqtt_message(hass, 'hermes/nlu/intentParsed', | ||
EXAMPLE_MSG) | ||
yield from hass.async_block_till_done() | ||
assert len(calls) == 1 | ||
call = calls[0] | ||
assert call.data.get('color') == 'blue' |