diff --git a/elastalert/__init__.py b/elastalert/__init__.py index 55bfdb32f..38100b0eb 100644 --- a/elastalert/__init__.py +++ b/elastalert/__init__.py @@ -16,7 +16,7 @@ def __init__(self, conf): """ :arg conf: es_conn_config dictionary. Ref. :func:`~util.build_es_conn_config` """ - super(ElasticSearchClient, self).__init__(host=conf['es_host'], + super(ElasticSearchClient, self).__init__(hosts=conf['es_host'], port=conf['es_port'], url_prefix=conf['es_url_prefix'], use_ssl=conf['use_ssl'], diff --git a/elastalert/util.py b/elastalert/util.py index bbb0600ff..4c5cff412 100644 --- a/elastalert/util.py +++ b/elastalert/util.py @@ -343,8 +343,10 @@ def build_es_conn_config(conf): parsed_conf['es_password'] = None parsed_conf['aws_region'] = None parsed_conf['profile'] = None - parsed_conf['es_host'] = os.environ.get('ES_HOST', conf['es_host']) - parsed_conf['es_port'] = int(os.environ.get('ES_PORT', conf['es_port'])) + es_host = os.environ.get('ES_HOST', conf['es_host']) + es_port = int(os.environ.get('ES_PORT', conf['es_port'])) + parsed_conf['es_host'] = parse_host(es_host, es_port) + parsed_conf['es_port'] = es_port parsed_conf['es_url_prefix'] = '' parsed_conf['es_conn_timeout'] = conf.get('es_conn_timeout', 20) parsed_conf['send_get_body_as'] = conf.get('es_send_get_body_as', 'GET') @@ -460,3 +462,19 @@ def should_scrolling_continue(rule_conf): stop_the_scroll = 0 < max_scrolling <= rule_conf.get('scrolling_cycle') return not stop_the_scroll + +def parse_host(host, port=9200): + """ + Convet host str like "host1:port1, host2:port2" to list + + :param host str: hostnames (separated with comma ) or single host name + :param port: default to 9200 + :return: list of hosts + """ + if "," in host: + host_list = host.split(",") + host_list = [x.strip() for x in host_list] + return host_list + else: + return ["{host}:{port}".format(host=host, port=port)] + diff --git a/tests/util_test.py b/tests/util_test.py index 55a2f9c8f..72d2b6ab2 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -6,7 +6,7 @@ import pytest from dateutil.parser import parse as dt -from elastalert.util import add_raw_postfix +from elastalert.util import add_raw_postfix, parse_host, build_es_conn_config from elastalert.util import format_index from elastalert.util import lookup_es_key from elastalert.util import parse_deadline @@ -228,3 +228,21 @@ def test_should_scrolling_continue(): assert should_scrolling_continue(rule_before_first_run) is True assert should_scrolling_continue(rule_before_max_scrolling) is True assert should_scrolling_continue(rule_over_max_scrolling) is False + + +def test_parse_host(): + assert parse_host("localhost", port=9200) == ["localhost:9200"] + assert parse_host("host1:9200, host2:9200, host3:9300") ==["host1:9200", + "host2:9200", + "host3:9300"] + +def test_build_cofig_for_multi(): + assert build_es_conn_config({ + "es_host":"localhost", + "es_port": 9200 + })['es_host'] == ['localhost:9200'] + + assert build_es_conn_config({ + "es_host": "host1:9200, host2:9200, host3:9300", + "es_port": 9200 + })['es_host'] == ["host1:9200","host2:9200","host3:9300"]