-
-
Notifications
You must be signed in to change notification settings - Fork 30.9k
/
trigger.py
376 lines (305 loc) · 11.1 KB
/
trigger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
"""Triggers."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
import functools
import logging
from typing import Any, Protocol, TypedDict, cast
import voluptuous as vol
from homeassistant.const import (
CONF_ALIAS,
CONF_ENABLED,
CONF_ID,
CONF_PLATFORM,
CONF_VARIABLES,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
HassJob,
HomeAssistant,
callback,
is_callback,
)
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from .template import Template
from .typing import ConfigType, TemplateVarsType
# Maps the platform name given in a trigger config to the name of the
# integration that is loaded to provide the trigger platform. Names that are
# not listed here are used unchanged (see _async_get_trigger_platform).
_PLATFORM_ALIASES = {
    "device": "device_automation",
    "event": "homeassistant",
    "numeric_state": "homeassistant",
    "state": "homeassistant",
    "time_pattern": "homeassistant",
    "time": "homeassistant",
}

# Key under hass.data that holds the pluggable actions registry. The registry
# is created lazily by PluggableAction.async_get_registry and is keyed by the
# sorted tuple of a trigger dict's items.
DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey(
    "pluggable_actions"
)
class TriggerProtocol(Protocol):
    """Define the format of trigger modules.

    Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
    When a module has async_validate_trigger_config it is used to validate a
    trigger config; otherwise TRIGGER_SCHEMA is called with the config.
    """

    # Schema called with the raw trigger config when the module does not
    # provide async_validate_trigger_config.
    TRIGGER_SCHEMA: vol.Schema

    async def async_validate_trigger_config(
        self, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config.

        Receives the config of a single trigger and returns the validated config.
        """

    async def async_attach_trigger(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        action: TriggerActionType,
        trigger_info: TriggerInfo,
    ) -> CALLBACK_TYPE:
        """Attach a trigger.

        The returned callable is collected by async_initialize_triggers and
        called later to remove the trigger again.
        """
class TriggerActionType(Protocol):
    """Protocol type for trigger action callback.

    Describes an awaitable callable that takes the run variables and an
    optional context.
    """

    async def __call__(
        self,
        run_variables: dict[str, Any],
        context: Context | None = None,
    ) -> Any:
        """Define action callback type."""
class TriggerData(TypedDict):
    """Trigger data.

    One instance is built per enabled trigger by async_initialize_triggers.
    """

    # Configured trigger ID, or the trigger's index as a string when no ID is set.
    id: str
    # Position of the trigger in the trigger config list, as a string.
    idx: str
    # Configured alias, or None when the trigger config has none.
    alias: str | None
class TriggerInfo(TypedDict):
    """Information about trigger.

    Built by async_initialize_triggers and handed to the platform's
    async_attach_trigger.
    """

    # domain, name, home_assistant_start and variables are the arguments of the
    # same names given to async_initialize_triggers, passed through unchanged.
    domain: str
    name: str
    home_assistant_start: bool
    variables: TemplateVarsType
    # Per-trigger identification (id, idx, alias).
    trigger_data: TriggerData
@dataclass(slots=True)
class PluggableActionsEntry:
    """Holder to keep track of all plugs and actions for a given trigger."""

    # PluggableAction instances currently registered for this trigger.
    plugs: set[PluggableAction] = field(default_factory=set)
    # Attached actions, keyed by the remove callback returned from
    # PluggableAction.async_attach_trigger. Each value is the job wrapping the
    # action together with the variables it is run with.
    actions: dict[
        object,
        tuple[
            HassJob[[dict[str, Any], Context | None], Coroutine[Any, Any, None]],
            dict[str, Any],
        ],
    ] = field(default_factory=dict)
class PluggableAction:
    """A pluggable action handler.

    A plug is registered for a trigger dict with async_register, and actions
    are attached to the same trigger dict with async_attach_trigger. Both sides
    meet in a shared PluggableActionsEntry stored in the registry under the
    sorted tuple of the trigger's items, so either side may arrive first.
    """

    # Entry this plug is currently registered on; None while unregistered.
    _entry: PluggableActionsEntry | None = None

    def __init__(self, update: CALLBACK_TYPE | None = None) -> None:
        """Initialize a pluggable action.

        :param update: callback triggered whenever triggers are attached or removed.
        """
        self._update = update

    def __bool__(self) -> bool:
        """Return if we have something attached."""
        return bool(self._entry and self._entry.actions)

    @callback
    def async_run_update(self) -> None:
        """Run update function if one exists."""
        if self._update:
            self._update()

    @staticmethod
    @callback
    def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
        """Return the pluggable actions registry, creating it on first use."""
        data = hass.data.get(DATA_PLUGGABLE_ACTIONS)
        # Compare against None rather than truthiness: an existing registry that
        # happens to be empty must be reused, not silently replaced.
        if data is None:
            data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(
                PluggableActionsEntry
            )
        return data

    @staticmethod
    @callback
    def async_attach_trigger(
        hass: HomeAssistant,
        trigger: dict[str, str],
        action: TriggerActionType,
        variables: dict[str, Any],
    ) -> CALLBACK_TYPE:
        """Attach an action to a trigger entry.

        Existing or future plugs registered will be attached.

        :param trigger: trigger description; its sorted items form the registry key.
        :param action: action to run when a plug for this trigger is run.
        :param variables: variables passed to the action when it is run.
        :return: callable that detaches the action again.
        """
        reg = PluggableAction.async_get_registry(hass)
        key = tuple(sorted(trigger.items()))
        entry = reg[key]

        def _update() -> None:
            # Let every registered plug know the set of actions changed.
            for plug in entry.plugs:
                plug.async_run_update()

        @callback
        def _remove() -> None:
            """Remove this action attachment, and disconnect all plugs."""
            # The remove callback itself is the key of the attachment.
            del entry.actions[_remove]
            _update()
            if not entry.actions and not entry.plugs:
                del reg[key]

        job = HassJob(action, f"trigger {trigger} {variables}")
        entry.actions[_remove] = (job, variables)
        _update()

        return _remove

    @callback
    def async_register(
        self, hass: HomeAssistant, trigger: dict[str, str]
    ) -> CALLBACK_TYPE:
        """Register plug in the global plugs dictionary.

        :param trigger: trigger description; its sorted items form the registry key.
        :return: callable that unregisters the plug again.
        """
        reg = PluggableAction.async_get_registry(hass)
        key = tuple(sorted(trigger.items()))
        self._entry = reg[key]
        self._entry.plugs.add(self)

        @callback
        def _remove() -> None:
            """Remove plug from registration.

            Clean up entry if there are no actions or plugs registered.
            """
            assert self._entry
            self._entry.plugs.remove(self)
            if not self._entry.actions and not self._entry.plugs:
                del reg[key]
            self._entry = None

        return _remove

    async def async_run(
        self, hass: HomeAssistant, context: Context | None = None
    ) -> None:
        """Run all actions.

        The plug must be registered (see async_register) before it is run.
        """
        assert self._entry
        # Iterate over a snapshot: actions may be attached or removed while a
        # task is awaited, and mutating the dict during iteration of its live
        # view would raise RuntimeError.
        for job, variables in list(self._entry.actions.values()):
            task = hass.async_run_hass_job(job, variables, context)
            if task:
                await task
async def _async_get_trigger_platform(
    hass: HomeAssistant, config: ConfigType
) -> TriggerProtocol:
    """Load the trigger platform named by a trigger config.

    Only the part of the configured platform before the first "." is used, and
    it is translated through _PLATFORM_ALIASES before the integration is loaded.

    Raises vol.Invalid when the integration is not found or when its trigger
    platform cannot be imported.
    """
    requested, _, _ = config[CONF_PLATFORM].partition(".")
    platform = _PLATFORM_ALIASES.get(requested, requested)

    try:
        integration = await async_get_integration(hass, platform)
    except IntegrationNotFound:
        raise vol.Invalid(f"Invalid platform '{platform}' specified") from None

    try:
        trigger_platform = await integration.async_get_platform("trigger")
    except ImportError:
        raise vol.Invalid(
            f"Integration '{platform}' does not provide trigger support"
        ) from None
    return trigger_platform
async def async_validate_trigger_config(
    hass: HomeAssistant, trigger_config: list[ConfigType]
) -> list[ConfigType]:
    """Validate every trigger config and return the validated configs in order.

    A platform's async_validate_trigger_config is used when it has one;
    otherwise the platform's TRIGGER_SCHEMA is applied.
    """
    validated: list[ConfigType] = []
    for raw_conf in trigger_config:
        platform = await _async_get_trigger_platform(hass, raw_conf)
        if not hasattr(platform, "async_validate_trigger_config"):
            validated.append(platform.TRIGGER_SCHEMA(raw_conf))
            continue
        validated.append(
            await platform.async_validate_trigger_config(hass, raw_conf)
        )
    return validated
def _trigger_action_wrapper(
    hass: HomeAssistant, action: Callable, conf: ConfigType
) -> Callable:
    """Wrap a trigger action so trigger-level variables are rendered first.

    When the trigger config has no variables, the action is returned unchanged.
    Otherwise the returned wrapper renders the configured variables against the
    run variables, merges the result into ``run_variables`` in place and then
    invokes the action.

    If the action is a coroutine function, it is awaited by the wrapper.
    Any other action is called directly from inside the wrapper, which is itself
    a coroutine function; it is additionally marked with ``callback`` when the
    action is a callback.
    """
    if CONF_VARIABLES not in conf:
        return action

    # Peel off functools.partial layers so the checks see the real target.
    target = action
    while isinstance(target, functools.partial):
        target = target.func

    def _merge_trigger_variables(run_variables: dict[str, Any]) -> None:
        # Looked up on every run, exactly like the action itself.
        rendered = conf[CONF_VARIABLES].async_render(hass, run_variables)
        run_variables.update(rendered)

    if asyncio.iscoroutinefunction(target):
        coro_action = cast(Callable[..., Coroutine[Any, Any, None]], action)

        @functools.wraps(coro_action)
        async def async_with_vars(
            run_variables: dict[str, Any], context: Context | None = None
        ) -> None:
            """Render the trigger variables, then await the action."""
            _merge_trigger_variables(run_variables)
            await coro_action(run_variables, context)

        return async_with_vars

    @functools.wraps(action)
    async def with_vars(
        run_variables: dict[str, Any], context: Context | None = None
    ) -> None:
        """Render the trigger variables, then call the action."""
        _merge_trigger_variables(run_variables)
        action(run_variables, context)

    return callback(with_vars) if is_callback(target) else with_vars
async def async_initialize_triggers(
    hass: HomeAssistant,
    trigger_config: list[ConfigType],
    action: Callable,
    domain: str,
    name: str,
    log_cb: Callable,
    home_assistant_start: bool = False,
    variables: TemplateVarsType = None,
) -> CALLBACK_TYPE | None:
    """Attach all enabled triggers and return a callable that removes them.

    Triggers whose "enabled" setting is false are skipped, as are triggers
    whose "enabled" template fails to render (the failure is reported through
    ``log_cb``). Errors raised while attaching a trigger are reported through
    ``log_cb`` and do not prevent the other triggers from being used;
    non-Exception BaseExceptions are re-raised.

    Returns None when no trigger could be attached.
    """
    attach_tasks: list[asyncio.Task[CALLBACK_TYPE]] = []

    for position, conf in enumerate(trigger_config):
        # Skip triggers that are not enabled
        if CONF_ENABLED in conf:
            enabled = conf[CONF_ENABLED]
            if isinstance(enabled, Template):
                try:
                    enabled = enabled.async_render(variables, limited=True)
                except TemplateError as err:
                    log_cb(logging.ERROR, f"Error rendering enabled template: {err}")
                    continue
            if not enabled:
                continue

        platform = await _async_get_trigger_platform(hass, conf)

        # The index doubles as the trigger ID when none is configured.
        index_text = f"{position}"
        info = TriggerInfo(
            domain=domain,
            name=name,
            home_assistant_start=home_assistant_start,
            variables=variables,
            trigger_data=TriggerData(
                id=conf.get(CONF_ID, index_text),
                idx=index_text,
                alias=conf.get(CONF_ALIAS),
            ),
        )
        wrapped_action = _trigger_action_wrapper(hass, action, conf)
        attach_tasks.append(
            create_eager_task(
                platform.async_attach_trigger(hass, conf, wrapped_action, info)
            )
        )

    attach_results = await asyncio.gather(*attach_tasks, return_exceptions=True)

    removes: list[Callable[[], None]] = []
    for result in attach_results:
        if isinstance(result, BaseException):
            if isinstance(result, HomeAssistantError):
                log_cb(
                    logging.ERROR,
                    f"Got error '{result}' when setting up triggers for",
                )
            elif isinstance(result, Exception):
                log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
            else:
                # Not an ordinary error; let it propagate.
                raise result from None
            continue
        if result is None:
            log_cb(  # type: ignore[unreachable]
                logging.ERROR, "Unknown error while setting up trigger (empty result)"
            )
            continue
        removes.append(result)

    if not removes:
        return None

    log_cb(logging.INFO, "Initialized trigger")

    @callback
    def remove_triggers() -> None:
        """Remove triggers."""
        for detach in removes:
            detach()

    return remove_triggers