Reorganize HTTP component (#4575)

* Move HTTP to own folder

* Break HTTP into middlewares

* Lint

* Split tests per middleware

* Clean up HTTP tests

* Make HomeAssistantViews more stateless

* Lint

* Make HTTP setup async
This commit is contained in:
Paulus Schoutsen 2016-11-25 13:04:06 -08:00 committed by GitHub
parent 58b85b2e0e
commit 32ffd006fa
35 changed files with 1318 additions and 1084 deletions

View File

@ -118,7 +118,7 @@ class AlexaIntentsView(HomeAssistantView):
def __init__(self, hass, intents):
"""Initialize Alexa view."""
super().__init__(hass)
super().__init__()
intents = copy.deepcopy(intents)
template.attach(hass, intents)
@ -150,7 +150,7 @@ class AlexaIntentsView(HomeAssistantView):
return None
intent = req.get('intent')
response = AlexaResponse(self.hass, intent)
response = AlexaResponse(request.app['hass'], intent)
if req_type == 'LaunchRequest':
response.add_speech(
@ -282,7 +282,7 @@ class AlexaFlashBriefingView(HomeAssistantView):
def __init__(self, hass, flash_briefings):
"""Initialize Alexa view."""
super().__init__(hass)
super().__init__()
self.flash_briefings = copy.deepcopy(flash_briefings)
template.attach(hass, self.flash_briefings)

View File

@ -77,8 +77,10 @@ class APIEventStream(HomeAssistantView):
@asyncio.coroutine
def get(self, request):
"""Provide a streaming interface for the event bus."""
# pylint: disable=no-self-use
hass = request.app['hass']
stop_obj = object()
to_write = asyncio.Queue(loop=self.hass.loop)
to_write = asyncio.Queue(loop=hass.loop)
restrict = request.GET.get('restrict')
if restrict:
@ -106,7 +108,7 @@ class APIEventStream(HomeAssistantView):
response.content_type = 'text/event-stream'
yield from response.prepare(request)
unsub_stream = self.hass.bus.async_listen(MATCH_ALL, forward_events)
unsub_stream = hass.bus.async_listen(MATCH_ALL, forward_events)
try:
_LOGGER.debug('STREAM %s ATTACHED', id(stop_obj))
@ -117,7 +119,7 @@ class APIEventStream(HomeAssistantView):
while True:
try:
with async_timeout.timeout(STREAM_PING_INTERVAL,
loop=self.hass.loop):
loop=hass.loop):
payload = yield from to_write.get()
if payload is stop_obj:
@ -145,7 +147,7 @@ class APIConfigView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get current configuration."""
return self.json(self.hass.config.as_dict())
return self.json(request.app['hass'].config.as_dict())
class APIDiscoveryView(HomeAssistantView):
@ -158,10 +160,11 @@ class APIDiscoveryView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get discovery info."""
needs_auth = self.hass.config.api.api_password is not None
hass = request.app['hass']
needs_auth = hass.config.api.api_password is not None
return self.json({
'base_url': self.hass.config.api.base_url,
'location_name': self.hass.config.location_name,
'base_url': hass.config.api.base_url,
'location_name': hass.config.location_name,
'requires_api_password': needs_auth,
'version': __version__
})
@ -176,7 +179,7 @@ class APIStatesView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get current states."""
return self.json(self.hass.states.async_all())
return self.json(request.app['hass'].states.async_all())
class APIEntityStateView(HomeAssistantView):
@ -188,7 +191,7 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback
def get(self, request, entity_id):
"""Retrieve state of entity."""
state = self.hass.states.get(entity_id)
state = request.app['hass'].states.get(entity_id)
if state:
return self.json(state)
else:
@ -197,6 +200,7 @@ class APIEntityStateView(HomeAssistantView):
@asyncio.coroutine
def post(self, request, entity_id):
"""Update state of entity."""
hass = request.app['hass']
try:
data = yield from request.json()
except ValueError:
@ -211,15 +215,14 @@ class APIEntityStateView(HomeAssistantView):
attributes = data.get('attributes')
force_update = data.get('force_update', False)
is_new_state = self.hass.states.get(entity_id) is None
is_new_state = hass.states.get(entity_id) is None
# Write state
self.hass.states.async_set(entity_id, new_state, attributes,
force_update)
hass.states.async_set(entity_id, new_state, attributes, force_update)
# Read the state back for our response
status_code = HTTP_CREATED if is_new_state else 200
resp = self.json(self.hass.states.get(entity_id), status_code)
resp = self.json(hass.states.get(entity_id), status_code)
resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id))
@ -228,7 +231,7 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback
def delete(self, request, entity_id):
"""Remove entity."""
if self.hass.states.async_remove(entity_id):
if request.app['hass'].states.async_remove(entity_id):
return self.json_message('Entity removed')
else:
return self.json_message('Entity not found', HTTP_NOT_FOUND)
@ -243,7 +246,7 @@ class APIEventListenersView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get event listeners."""
return self.json(async_events_json(self.hass))
return self.json(async_events_json(request.app['hass']))
class APIEventView(HomeAssistantView):
@ -271,7 +274,8 @@ class APIEventView(HomeAssistantView):
if state:
event_data[key] = state
self.hass.bus.async_fire(event_type, event_data, ha.EventOrigin.remote)
request.app['hass'].bus.async_fire(event_type, event_data,
ha.EventOrigin.remote)
return self.json_message("Event {} fired.".format(event_type))
@ -285,7 +289,7 @@ class APIServicesView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get registered services."""
return self.json(async_services_json(self.hass))
return self.json(async_services_json(request.app['hass']))
class APIDomainServicesView(HomeAssistantView):
@ -300,12 +304,12 @@ class APIDomainServicesView(HomeAssistantView):
Returns a list of changed states.
"""
hass = request.app['hass']
body = yield from request.text()
data = json.loads(body) if body else None
with AsyncTrackStates(self.hass) as changed_states:
yield from self.hass.services.async_call(domain, service, data,
True)
with AsyncTrackStates(hass) as changed_states:
yield from hass.services.async_call(domain, service, data, True)
return self.json(changed_states)
@ -320,6 +324,7 @@ class APIEventForwardingView(HomeAssistantView):
@asyncio.coroutine
def post(self, request):
"""Setup an event forwarder."""
hass = request.app['hass']
try:
data = yield from request.json()
except ValueError:
@ -340,14 +345,14 @@ class APIEventForwardingView(HomeAssistantView):
api = rem.API(host, api_password, port)
valid = yield from self.hass.loop.run_in_executor(
valid = yield from hass.loop.run_in_executor(
None, api.validate_api)
if not valid:
return self.json_message("Unable to validate API.",
HTTP_UNPROCESSABLE_ENTITY)
if self.event_forwarder is None:
self.event_forwarder = rem.EventForwarder(self.hass)
self.event_forwarder = rem.EventForwarder(hass)
self.event_forwarder.async_connect(api)
@ -389,7 +394,7 @@ class APIComponentsView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get current loaded components."""
return self.json(self.hass.config.components)
return self.json(request.app['hass'].config.components)
class APIErrorLogView(HomeAssistantView):
@ -402,7 +407,7 @@ class APIErrorLogView(HomeAssistantView):
def get(self, request):
"""Serve error log."""
resp = yield from self.file(
request, self.hass.config.path(ERROR_LOG_FILENAME))
request, request.app['hass'].config.path(ERROR_LOG_FILENAME))
return resp
@ -417,7 +422,7 @@ class APITemplateView(HomeAssistantView):
"""Render a template."""
try:
data = yield from request.json()
tpl = template.Template(data['template'], self.hass)
tpl = template.Template(data['template'], request.app['hass'])
return tpl.async_render(data.get('variables'))
except (ValueError, TemplateError) as ex:
return self.json_message('Error rendering template: {}'.format(ex),

View File

@ -13,7 +13,7 @@ from aiohttp import web
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
DOMAIN = 'camera'
DEPENDENCIES = ['http']
@ -33,8 +33,8 @@ def async_setup(hass, config):
component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(CameraImageView(hass, component.entities))
hass.http.register_view(CameraMjpegStream(hass, component.entities))
hass.http.register_view(CameraImageView(component.entities))
hass.http.register_view(CameraMjpegStream(component.entities))
yield from component.async_setup(config)
return True
@ -165,9 +165,8 @@ class CameraView(HomeAssistantView):
requires_auth = False
def __init__(self, hass, entities):
def __init__(self, entities):
"""Initialize a basic camera view."""
super().__init__(hass)
self.entities = entities
@asyncio.coroutine
@ -178,7 +177,7 @@ class CameraView(HomeAssistantView):
if camera is None:
return web.Response(status=404)
authenticated = (request.authenticated or
authenticated = (request[KEY_AUTHENTICATED] or
request.GET.get('token') == camera.access_token)
if not authenticated:

View File

@ -21,7 +21,7 @@ DEPENDENCIES = ['http']
def setup_scanner(hass, config, see):
"""Setup an endpoint for the GPSLogger application."""
hass.http.register_view(GPSLoggerView(hass, see))
hass.http.register_view(GPSLoggerView(see))
return True
@ -32,20 +32,18 @@ class GPSLoggerView(HomeAssistantView):
url = '/api/gpslogger'
name = 'api:gpslogger'
def __init__(self, hass, see):
def __init__(self, see):
"""Initialize GPSLogger url endpoints."""
super().__init__(hass)
self.see = see
@asyncio.coroutine
def get(self, request):
"""A GPSLogger message received as GET."""
res = yield from self._handle(request.GET)
res = yield from self._handle(request.app['hass'], request.GET)
return res
@asyncio.coroutine
# pylint: disable=too-many-return-statements
def _handle(self, data):
def _handle(self, hass, data):
"""Handle gpslogger request."""
if 'latitude' not in data or 'longitude' not in data:
return ('Latitude and longitude not specified.',
@ -66,7 +64,7 @@ class GPSLoggerView(HomeAssistantView):
if 'battery' in data:
battery = float(data['battery'])
yield from self.hass.loop.run_in_executor(
yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device,
gps=gps_location, battery=battery,
gps_accuracy=accuracy))

View File

@ -23,7 +23,7 @@ DEPENDENCIES = ['http']
def setup_scanner(hass, config, see):
"""Setup an endpoint for the Locative application."""
hass.http.register_view(LocativeView(hass, see))
hass.http.register_view(LocativeView(see))
return True
@ -34,27 +34,26 @@ class LocativeView(HomeAssistantView):
url = '/api/locative'
name = 'api:locative'
def __init__(self, hass, see):
def __init__(self, see):
"""Initialize Locative url endpoints."""
super().__init__(hass)
self.see = see
@asyncio.coroutine
def get(self, request):
"""Locative message received as GET."""
res = yield from self._handle(request.GET)
res = yield from self._handle(request.app['hass'], request.GET)
return res
@asyncio.coroutine
def post(self, request):
"""Locative message received."""
data = yield from request.post()
res = yield from self._handle(data)
res = yield from self._handle(request.app['hass'], data)
return res
@asyncio.coroutine
# pylint: disable=too-many-return-statements
def _handle(self, data):
def _handle(self, hass, data):
"""Handle locative request."""
if 'latitude' not in data or 'longitude' not in data:
return ('Latitude and longitude not specified.',
@ -81,19 +80,19 @@ class LocativeView(HomeAssistantView):
gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE])
if direction == 'enter':
yield from self.hass.loop.run_in_executor(
yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device,
location_name=location_name,
gps=gps_location))
return 'Setting location to {}'.format(location_name)
elif direction == 'exit':
current_state = self.hass.states.get(
current_state = hass.states.get(
'{}.{}'.format(DOMAIN, device))
if current_state is None or current_state.state == location_name:
location_name = STATE_NOT_HOME
yield from self.hass.loop.run_in_executor(
yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device,
location_name=location_name,
gps=gps_location))

View File

@ -78,14 +78,13 @@ def setup(hass, yaml_config):
cors_origins=None,
use_x_forwarded_for=False,
trusted_networks=None,
ip_bans=None,
login_threshold=0,
is_ban_enabled=False
)
server.register_view(DescriptionXmlView(hass, config))
server.register_view(HueUsernameView(hass))
server.register_view(HueLightsView(hass, config))
server.register_view(DescriptionXmlView(config))
server.register_view(HueUsernameView)
server.register_view(HueLightsView(config))
upnp_listener = UPNPResponderThread(
config.host_ip_addr, config.listen_port)
@ -157,9 +156,8 @@ class DescriptionXmlView(HomeAssistantView):
name = 'description:xml'
requires_auth = False
def __init__(self, hass, config):
def __init__(self, config):
"""Initialize the instance of the view."""
super().__init__(hass)
self.config = config
@core.callback
@ -201,10 +199,6 @@ class HueUsernameView(HomeAssistantView):
extra_urls = ['/api/']
requires_auth = False
def __init__(self, hass):
"""Initialize the instance of the view."""
super().__init__(hass)
@asyncio.coroutine
def post(self, request):
"""Handle a POST request."""
@ -229,30 +223,33 @@ class HueLightsView(HomeAssistantView):
'/api/{username}/lights/{entity_id}/state']
requires_auth = False
def __init__(self, hass, config):
def __init__(self, config):
"""Initialize the instance of the view."""
super().__init__(hass)
self.config = config
self.cached_states = {}
@core.callback
def get(self, request, username, entity_id=None):
"""Handle a GET request."""
hass = request.app['hass']
if entity_id is None:
return self.async_get_lights_list()
return self.async_get_lights_list(hass)
if not request.path.endswith('state'):
return self.async_get_light_state(entity_id)
return self.async_get_light_state(hass, entity_id)
return web.Response(text="Method not allowed", status=405)
@asyncio.coroutine
def put(self, request, username, entity_id=None):
"""Handle a PUT request."""
hass = request.app['hass']
if not request.path.endswith('state'):
return web.Response(text="Method not allowed", status=405)
if entity_id and self.hass.states.get(entity_id) is None:
if entity_id and hass.states.get(entity_id) is None:
return self.json_message('Entity not found', HTTP_NOT_FOUND)
try:
@ -260,24 +257,25 @@ class HueLightsView(HomeAssistantView):
except ValueError:
return self.json_message('Invalid JSON', HTTP_BAD_REQUEST)
result = yield from self.async_put_light_state(json_data, entity_id)
result = yield from self.async_put_light_state(hass, json_data,
entity_id)
return result
@core.callback
def async_get_lights_list(self):
def async_get_lights_list(self, hass):
"""Process a request to get the list of available lights."""
json_response = {}
for entity in self.hass.states.async_all():
for entity in hass.states.async_all():
if self.is_entity_exposed(entity):
json_response[entity.entity_id] = entity_to_json(entity)
return self.json(json_response)
@core.callback
def async_get_light_state(self, entity_id):
def async_get_light_state(self, hass, entity_id):
"""Process a request to get the state of an individual light."""
entity = self.hass.states.get(entity_id)
entity = hass.states.get(entity_id)
if entity is None or not self.is_entity_exposed(entity):
return web.Response(text="Entity not found", status=404)
@ -295,12 +293,12 @@ class HueLightsView(HomeAssistantView):
return self.json(json_response)
@asyncio.coroutine
def async_put_light_state(self, request_json, entity_id):
def async_put_light_state(self, hass, request_json, entity_id):
"""Process a request to set the state of an individual light."""
config = self.config
# Retrieve the entity from the state machine
entity = self.hass.states.get(entity_id)
entity = hass.states.get(entity_id)
if entity is None:
return web.Response(text="Entity not found", status=404)
@ -345,8 +343,8 @@ class HueLightsView(HomeAssistantView):
self.cached_states[entity_id] = (result, brightness)
# Perform the requested action
yield from self.hass.services.async_call(core.DOMAIN, service, data,
blocking=True)
yield from hass.services.async_call(core.DOMAIN, service, data,
blocking=True)
json_response = \
[create_hue_success_response(entity_id, HUE_API_STATE_ON, result)]

View File

@ -75,8 +75,7 @@ def setup(hass, config):
descriptions[DOMAIN][SERVICE_CHECKIN],
schema=CHECKIN_SERVICE_SCHEMA)
hass.http.register_view(FoursquarePushReceiver(
hass, config[CONF_PUSH_SECRET]))
hass.http.register_view(FoursquarePushReceiver(config[CONF_PUSH_SECRET]))
return True
@ -88,9 +87,8 @@ class FoursquarePushReceiver(HomeAssistantView):
url = "/api/foursquare"
name = "foursquare"
def __init__(self, hass, push_secret):
def __init__(self, push_secret):
"""Initialize the OAuth callback view."""
super().__init__(hass)
self.push_secret = push_secret
@asyncio.coroutine
@ -110,4 +108,4 @@ class FoursquarePushReceiver(HomeAssistantView):
"push secret: %s", secret)
return self.json_message('Incorrect secret', HTTP_BAD_REQUEST)
self.hass.bus.async_fire(EVENT_PUSH, data)
request.app['hass'].bus.async_fire(EVENT_PUSH, data)

View File

@ -11,6 +11,8 @@ from homeassistant.core import callback
from homeassistant.const import HTTP_NOT_FOUND
from homeassistant.components import api, group
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import is_trusted_ip
from homeassistant.components.http.const import KEY_DEVELOPMENT
from .version import FINGERPRINTS
DOMAIN = 'frontend'
@ -155,7 +157,7 @@ def setup(hass, config):
if os.path.isdir(local):
hass.http.register_static_path("/local", local)
index_view = hass.data[DATA_INDEX_VIEW] = IndexView(hass)
index_view = hass.data[DATA_INDEX_VIEW] = IndexView()
hass.http.register_view(index_view)
# Components have registered panels before frontend got setup.
@ -185,12 +187,14 @@ class BootstrapView(HomeAssistantView):
@callback
def get(self, request):
"""Return all data needed to bootstrap Home Assistant."""
hass = request.app['hass']
return self.json({
'config': self.hass.config.as_dict(),
'states': self.hass.states.async_all(),
'events': api.async_events_json(self.hass),
'services': api.async_services_json(self.hass),
'panels': self.hass.data[DATA_PANELS],
'config': hass.config.as_dict(),
'states': hass.states.async_all(),
'events': api.async_events_json(hass),
'services': api.async_services_json(hass),
'panels': hass.data[DATA_PANELS],
})
@ -202,10 +206,8 @@ class IndexView(HomeAssistantView):
requires_auth = False
extra_urls = ['/states', '/states/{entity_id}']
def __init__(self, hass):
def __init__(self):
"""Initialize the frontend view."""
super().__init__(hass)
from jinja2 import FileSystemLoader, Environment
self.templates = Environment(
@ -217,14 +219,16 @@ class IndexView(HomeAssistantView):
@asyncio.coroutine
def get(self, request, entity_id=None):
"""Serve the index view."""
hass = request.app['hass']
if entity_id is not None:
state = self.hass.states.get(entity_id)
state = hass.states.get(entity_id)
if (not state or state.domain != 'group' or
not state.attributes.get(group.ATTR_VIEW)):
return self.json_message('Entity not found', HTTP_NOT_FOUND)
if self.hass.http.development:
if request.app[KEY_DEVELOPMENT]:
core_url = '/static/home-assistant-polymer/build/core.js'
ui_url = '/static/home-assistant-polymer/src/home-assistant.html'
else:
@ -241,19 +245,18 @@ class IndexView(HomeAssistantView):
if panel == 'states':
panel_url = ''
else:
panel_url = self.hass.data[DATA_PANELS][panel]['url']
panel_url = hass.data[DATA_PANELS][panel]['url']
no_auth = 'true'
if self.hass.config.api.api_password:
if hass.config.api.api_password:
# require password if set
no_auth = 'false'
if self.hass.http.is_trusted_ip(
self.hass.http.get_real_ip(request)):
if is_trusted_ip(request):
# bypass for trusted networks
no_auth = 'true'
icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html'])
template = yield from self.hass.loop.run_in_executor(
template = yield from hass.loop.run_in_executor(
None, self.templates.get_template, 'index.html')
# pylint is wrong
@ -262,7 +265,7 @@ class IndexView(HomeAssistantView):
resp = template.render(
core_url=core_url, ui_url=ui_url, no_auth=no_auth,
icons_url=icons_url, icons=FINGERPRINTS['mdi.html'],
panel_url=panel_url, panels=self.hass.data[DATA_PANELS])
panel_url=panel_url, panels=hass.data[DATA_PANELS])
return web.Response(text=resp, content_type='text/html')

View File

@ -184,8 +184,8 @@ def setup(hass, config):
filters.included_entities = include[CONF_ENTITIES]
filters.included_domains = include[CONF_DOMAINS]
hass.http.register_view(Last5StatesView(hass))
hass.http.register_view(HistoryPeriodView(hass, filters))
hass.http.register_view(Last5StatesView)
hass.http.register_view(HistoryPeriodView(filters))
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
return True
@ -197,14 +197,10 @@ class Last5StatesView(HomeAssistantView):
url = '/api/history/entity/{entity_id}/recent_states'
name = 'api:history:entity-recent-states'
def __init__(self, hass):
"""Initilalize the history last 5 states view."""
super().__init__(hass)
@asyncio.coroutine
def get(self, request, entity_id):
"""Retrieve last 5 states of entity."""
result = yield from self.hass.loop.run_in_executor(
result = yield from request.app['hass'].loop.run_in_executor(
None, last_5_states, entity_id)
return self.json(result)
@ -216,9 +212,8 @@ class HistoryPeriodView(HomeAssistantView):
name = 'api:history:view-period'
extra_urls = ['/api/history/period/{datetime}']
def __init__(self, hass, filters):
def __init__(self, filters):
"""Initilalize the history period view."""
super().__init__(hass)
self.filters = filters
@asyncio.coroutine
@ -240,7 +235,7 @@ class HistoryPeriodView(HomeAssistantView):
end_time = start_time + one_day
entity_id = request.GET.get('filter_entity_id')
result = yield from self.hass.loop.run_in_executor(
result = yield from request.app['hass'].loop.run_in_executor(
None, get_significant_states, start_time, end_time, entity_id,
self.filters)

View File

@ -1,641 +0,0 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
import mimetypes
import ssl
from datetime import datetime
from ipaddress import ip_address, ip_network
from pathlib import Path
import hmac
import os
import re
import voluptuous as vol
from aiohttp import web, hdrs
from aiohttp.file_sender import FileSender
from aiohttp.web_exceptions import (
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden)
from aiohttp.web_urldispatcher import StaticResource
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
from homeassistant import util
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.const import (
SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL,
CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR)
from homeassistant.core import is_callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.yaml import dump
DOMAIN = 'http'
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
CONF_API_PASSWORD = 'api_password'
CONF_SERVER_HOST = 'server_host'
CONF_SERVER_PORT = 'server_port'
CONF_DEVELOPMENT = 'development'
CONF_SSL_CERTIFICATE = 'ssl_certificate'
CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins'
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
DATA_API_PASSWORD = 'api_password'
NOTIFICATION_ID_LOGIN = 'http-login'
NOTIFICATION_ID_BAN = 'ip-ban'
IP_BANS = 'ip_bans.yaml'
ATTR_BANNED_AT = "banned_at"
# TLS configuation follows the best-practice guidelines specified here:
# https://wiki.mozilla.org/Security/Server_Side_TLS
# Intermediate guidelines are followed.
SSL_VERSION = ssl.PROTOCOL_SSLv23
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
if hasattr(ssl, 'OP_NO_COMPRESSION'):
SSL_OPTS |= ssl.OP_NO_COMPRESSION
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
"AES256-SHA:DES-CBC3-SHA:!DSS"
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
vol.Optional(CONF_API_PASSWORD): cv.string,
vol.Optional(CONF_SERVER_HOST): cv.string,
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
vol.Optional(CONF_DEVELOPMENT): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile,
vol.Optional(CONF_SSL_KEY): cv.isfile,
vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
vol.Optional(CONF_TRUSTED_NETWORKS):
vol.All(cv.ensure_list, [ip_network]),
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int,
vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean
}),
}, extra=vol.ALLOW_EXTRA)
# TEMP TO GET TESTS TO RUN
def request_class():
"""."""
raise Exception('not implemented')
class HideSensitiveFilter(logging.Filter):
"""Filter API password calls."""
def __init__(self, hass):
"""Initialize sensitive data filter."""
super().__init__()
self.hass = hass
def filter(self, record):
"""Hide sensitive data in messages."""
if self.hass.http.api_password is None:
return True
record.msg = record.msg.replace(self.hass.http.api_password, '*******')
return True
def setup(hass, config):
"""Set up the HTTP API and debug interface."""
logging.getLogger('aiohttp.access').addFilter(HideSensitiveFilter(hass))
conf = config.get(DOMAIN, {})
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
development = str(conf.get(CONF_DEVELOPMENT, '')) == '1'
ssl_certificate = conf.get(CONF_SSL_CERTIFICATE)
ssl_key = conf.get(CONF_SSL_KEY)
cors_origins = conf.get(CONF_CORS_ORIGINS, [])
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
trusted_networks = [
ip_network(trusted_network)
for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])]
is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False))
login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1))
ip_bans = load_ip_bans_config(hass.config.path(IP_BANS))
server = HomeAssistantWSGI(
hass,
development=development,
server_host=server_host,
server_port=server_port,
api_password=api_password,
ssl_certificate=ssl_certificate,
ssl_key=ssl_key,
cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for,
trusted_networks=trusted_networks,
ip_bans=ip_bans,
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled
)
@asyncio.coroutine
def stop_server(event):
"""Callback to stop the server."""
yield from server.stop()
@asyncio.coroutine
def start_server(event):
"""Callback to start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start()
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_server)
hass.http = server
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
else util.get_local_ip(),
api_password, server_port,
ssl_certificate is not None)
return True
class GzipFileSender(FileSender):
"""FileSender class capable of sending gzip version if available."""
# pylint: disable=invalid-name
development = False
@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
gzip = False
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
gzip_path = filepath.with_name(filepath.name + '.gz')
if gzip_path.is_file():
filepath = gzip_path
gzip = True
st = filepath.stat()
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
raise HTTPNotModified()
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
resp = self._response_factory()
resp.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
resp.last_modified = st.st_mtime
# CACHE HACK
if not self.development:
cache_time = 31 * 86400 # = 1 month
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time)
file_size = st.st_size
resp.content_length = file_size
with filepath.open('rb') as f:
yield from self._sendfile(request, resp, f, file_size)
return resp
_GZIP_FILE_SENDER = GzipFileSender()
@asyncio.coroutine
def staticresource_enhancer(app, handler):
"""Enhance StaticResourceHandler.
Adds gzip encoding and fingerprinting matching.
"""
inst = getattr(handler, '__self__', None)
if not isinstance(inst, StaticResource):
return handler
# pylint: disable=protected-access
inst._file_sender = _GZIP_FILE_SENDER
@asyncio.coroutine
def middleware_handler(request):
"""Strip out fingerprints from resource names."""
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
if fingerprinted:
request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups())
resp = yield from handler(request)
return resp
return middleware_handler
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks,
ip_bans, login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server."""
import aiohttp_cors
self.app = web.Application(middlewares=[staticresource_enhancer],
loop=hass.loop)
self.hass = hass
self.development = development
self.api_password = api_password
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self.use_x_forwarded_for = use_x_forwarded_for
self.trusted_networks = trusted_networks \
if trusted_networks is not None else []
self.event_forwarder = None
self._handler = None
self.server = None
self.login_threshold = login_threshold
self.ip_bans = ip_bans if ip_bans is not None else []
self.failed_login_attempts = {}
self.is_ban_enabled = is_ban_enabled
if cors_origins:
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
# CACHE HACK
_GZIP_FILE_SENDER.development = development
def register_view(self, view):
"""Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView.
It is optional to instantiate it before registering; this method will
handle it either way.
"""
if isinstance(view, type):
# Instantiate the view, if needed
view = view(self.hass)
view.register(self.app.router)
def register_redirect(self, url, redirect_to):
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
callable it's called with the url adapter that triggered the match and
the values of the URL as keyword arguments and has to return the target
for the redirect, otherwise it has to be a string with placeholders in
rule syntax.
"""
def redirect(request):
"""Redirect to location."""
raise HTTPMovedPermanently(redirect_to)
self.app.router.add_route('GET', url, redirect)
def register_static_path(self, url_root, path, cache_length=31):
"""Register a folder to serve as a static path.
Specify optional cache length of asset in days.
"""
if os.path.isdir(path):
self.app.router.add_static(url_root, path)
return
filepath = Path(path)
@asyncio.coroutine
def serve_file(request):
"""Redirect to location."""
res = yield from _GZIP_FILE_SENDER.send(request, filepath)
return res
# aiohttp supports regex matching for variables. Using that as temp
# to work around cache busting MD5.
# Turns something like /static/dev-panel.html into
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
base, ext = url_root.rsplit('.', 1)
base, file = base.rsplit('/', 1)
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
url_pattern = "{}/{{filename:{}}}".format(base, regex)
self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine
def start(self):
"""Start the wsgi server."""
if self.cors is not None:
for route in list(self.app.router.routes()):
self.cors.add(route)
if self.ssl_certificate:
context = ssl.SSLContext(SSL_VERSION)
context.options |= SSL_OPTS
context.set_ciphers(CIPHERS)
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
else:
context = None
self._handler = self.app.make_handler()
self.server = yield from self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context)
@asyncio.coroutine
def stop(self):
"""Stop the wsgi server."""
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
yield from self._handler.finish_connections(60.0)
yield from self.app.cleanup()
def get_real_ip(self, request):
"""Return the clients correct ip address, even in proxied setups."""
if self.use_x_forwarded_for \
and HTTP_HEADER_X_FORWARDED_FOR in request.headers:
return request.headers.get(
HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]
else:
peername = request.transport.get_extra_info('peername')
return peername[0] if peername is not None else None
def is_trusted_ip(self, remote_addr):
"""Match an ip address against trusted CIDR networks."""
return any(ip_address(remote_addr) in trusted_network
for trusted_network in self.hass.http.trusted_networks)
def wrong_login_attempt(self, remote_addr):
"""Registering wrong login attempt."""
if not self.is_ban_enabled or self.login_threshold < 1:
return
if remote_addr in self.failed_login_attempts:
self.failed_login_attempts[remote_addr] += 1
else:
self.failed_login_attempts[remote_addr] = 1
if self.failed_login_attempts[remote_addr] > self.login_threshold:
new_ban = IpBan(remote_addr)
self.ip_bans.append(new_ban)
update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban)
_LOGGER.warning('Banned IP %s for too many login attempts',
remote_addr)
persistent_notification.async_create(
self.hass,
'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN)
def is_banned_ip(self, remote_addr):
"""Check if IP address is in a ban list."""
if not self.is_ban_enabled:
return False
ip_address_ = ip_address(remote_addr)
for ip_ban in self.ip_bans:
if ip_ban.ip_address == ip_address_:
return True
return False
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
def __init__(self, hass):
"""Initilalize the base view."""
if not hasattr(self, 'url'):
class_name = self.__class__.__name__
raise AttributeError(
'{0} missing required attribute "url"'.format(class_name)
)
if not hasattr(self, 'name'):
class_name = self.__class__.__name__
raise AttributeError(
'{0} missing required attribute "name"'.format(class_name)
)
self.hass = hass
# pylint: disable=no-self-use
def json(self, result, status_code=200):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
return web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
def json_message(self, error, status_code=200):
"""Return a JSON message response."""
return self.json({'message': error}, status_code)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
response = yield from _GZIP_FILE_SENDER.send(request, Path(fil))
return response
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Factory to wrap our handler classes.
Eventually authentication should be managed by middleware.
"""
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not view.hass.is_running:
return web.Response(status=503)
remote_addr = view.hass.http.get_real_ip(request)
if view.hass.http.is_banned_ip(remote_addr):
raise HTTPForbidden()
# Auth code verbose on purpose
authenticated = False
if view.hass.http.api_password is None:
authenticated = True
elif view.hass.http.is_trusted_ip(remote_addr):
authenticated = True
elif hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
view.hass.http.api_password):
# A valid auth header has been set
authenticated = True
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
view.hass.http.api_password):
authenticated = True
if view.requires_auth and not authenticated:
view.hass.http.wrong_login_attempt(remote_addr)
_LOGGER.warning('Login attempt or request with an invalid '
'password from %s', remote_addr)
persistent_notification.async_create(
view.hass,
'Invalid password used from {}'.format(remote_addr),
'Login attempt failed', NOTIFICATION_ID_LOGIN)
raise HTTPUnauthorized()
request.authenticated = authenticated
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, remote_addr, authenticated)
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle
class IpBan(object):
"""Represents banned IP address."""
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
"""Initializing Ip Ban object."""
self.ip_address = ip_address(ip_ban)
self.banned_at = banned_at
if self.banned_at is None:
self.banned_at = datetime.utcnow()
def load_ip_bans_config(path: str):
"""Loading list of banned IPs from config file."""
ip_list = []
ip_schema = vol.Schema({
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
})
try:
try:
list_ = load_yaml_config_file(path)
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
for ip_ban, ip_info in list_.items():
try:
ip_info = ip_schema(ip_info)
ip_info['ip_ban'] = ip_address(ip_ban)
ip_list.append(IpBan(**ip_info))
except vol.Invalid:
_LOGGER.exception('Failed to load IP ban')
continue
except(HomeAssistantError, FileNotFoundError):
# No need to report error, file absence means
# that no bans were applied.
return []
return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan):
"""Update config file with new banned IP address."""
with open(path, 'a') as out:
ip_ = {str(ip_ban.ip_address): {
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
}}
out.write('\n')
out.write(dump(ip_))

View File

@ -0,0 +1,407 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
import ssl
from ipaddress import ip_network
from pathlib import Path
import os
import voluptuous as vol
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
from homeassistant.util import get_local_ip
from homeassistant.components import persistent_notification
from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START)
from homeassistant.core import is_callback
from homeassistant.util.logging import HideSensitiveDataFilter
from .auth import auth_middleware
from .ban import ban_middleware, process_wrong_login
from .const import (
KEY_USE_X_FORWARDED_FOR, KEY_TRUSTED_NETWORKS,
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD,
KEY_DEVELOPMENT, KEY_AUTHENTICATED)
from .static import GZIP_FILE_SENDER, staticresource_middleware
from .util import get_real_ip
DOMAIN = 'http'
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
CONF_API_PASSWORD = 'api_password'
CONF_SERVER_HOST = 'server_host'
CONF_SERVER_PORT = 'server_port'
CONF_DEVELOPMENT = 'development'
CONF_SSL_CERTIFICATE = 'ssl_certificate'
CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins'
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
NOTIFICATION_ID_LOGIN = 'http-login'
# TLS configuation follows the best-practice guidelines specified here:
# https://wiki.mozilla.org/Security/Server_Side_TLS
# Intermediate guidelines are followed.
SSL_VERSION = ssl.PROTOCOL_SSLv23
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
if hasattr(ssl, 'OP_NO_COMPRESSION'):
SSL_OPTS |= ssl.OP_NO_COMPRESSION
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
"AES256-SHA:DES-CBC3-SHA:!DSS"
_LOGGER = logging.getLogger(__name__)
DEFAULT_SERVER_HOST = '0.0.0.0'
DEFAULT_DEVELOPMENT = '0'
DEFAULT_LOGIN_ATTEMPT_THRESHOLD = -1
HTTP_SCHEMA = vol.Schema({
vol.Optional(CONF_API_PASSWORD, default=None): cv.string,
vol.Optional(CONF_SERVER_HOST, default=DEFAULT_SERVER_HOST): cv.string,
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
vol.Optional(CONF_DEVELOPMENT, default=DEFAULT_DEVELOPMENT): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE, default=None): cv.isfile,
vol.Optional(CONF_SSL_KEY, default=None): cv.isfile,
vol.Optional(CONF_CORS_ORIGINS, default=[]): vol.All(cv.ensure_list,
[cv.string]),
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
vol.Optional(CONF_TRUSTED_NETWORKS, default=[]):
vol.All(cv.ensure_list, [ip_network]),
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD,
default=DEFAULT_LOGIN_ATTEMPT_THRESHOLD): cv.positive_int,
vol.Optional(CONF_IP_BAN_ENABLED, default=True): cv.boolean
})
CONFIG_SCHEMA = vol.Schema({
DOMAIN: HTTP_SCHEMA,
}, extra=vol.ALLOW_EXTRA)
@asyncio.coroutine
def async_setup(hass, config):
"""Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN)
if conf is None:
conf = HTTP_SCHEMA({})
api_password = conf[CONF_API_PASSWORD]
server_host = conf[CONF_SERVER_HOST]
server_port = conf[CONF_SERVER_PORT]
development = conf[CONF_DEVELOPMENT] == '1'
ssl_certificate = conf[CONF_SSL_CERTIFICATE]
ssl_key = conf[CONF_SSL_KEY]
cors_origins = conf[CONF_CORS_ORIGINS]
use_x_forwarded_for = conf[CONF_USE_X_FORWARDED_FOR]
trusted_networks = conf[CONF_TRUSTED_NETWORKS]
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
if api_password is not None:
logging.getLogger('aiohttp.access').addFilter(
HideSensitiveDataFilter(api_password))
server = HomeAssistantWSGI(
hass,
development=development,
server_host=server_host,
server_port=server_port,
api_password=api_password,
ssl_certificate=ssl_certificate,
ssl_key=ssl_key,
cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for,
trusted_networks=trusted_networks,
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled
)
@asyncio.coroutine
def stop_server(event):
"""Callback to stop the server."""
yield from server.stop()
@asyncio.coroutine
def start_server(event):
"""Callback to start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
hass.http = server
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
else get_local_ip(),
api_password, server_port,
ssl_certificate is not None)
return True
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks,
login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server."""
import aiohttp_cors
middlewares = [auth_middleware, staticresource_middleware]
if is_ban_enabled:
middlewares.insert(0, ban_middleware)
self.app = web.Application(middlewares=middlewares, loop=hass.loop)
self.app['hass'] = hass
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
self.app[KEY_BANS_ENABLED] = is_ban_enabled
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
self.app[KEY_DEVELOPMENT] = development
self.hass = hass
self.development = development
self.api_password = api_password
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self._handler = None
self.server = None
if cors_origins:
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
def register_view(self, view):
"""Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView.
It is optional to instantiate it before registering; this method will
handle it either way.
"""
if isinstance(view, type):
# Instantiate the view, if needed
view = view()
if not hasattr(view, 'url'):
class_name = view.__class__.__name__
raise AttributeError(
'{0} missing required attribute "url"'.format(class_name)
)
if not hasattr(view, 'name'):
class_name = view.__class__.__name__
raise AttributeError(
'{0} missing required attribute "name"'.format(class_name)
)
view.register(self.app.router)
def register_redirect(self, url, redirect_to):
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
callable it's called with the url adapter that triggered the match and
the values of the URL as keyword arguments and has to return the target
for the redirect, otherwise it has to be a string with placeholders in
rule syntax.
"""
def redirect(request):
"""Redirect to location."""
raise HTTPMovedPermanently(redirect_to)
self.app.router.add_route('GET', url, redirect)
def register_static_path(self, url_root, path, cache_length=31):
"""Register a folder to serve as a static path.
Specify optional cache length of asset in days.
"""
if os.path.isdir(path):
self.app.router.add_static(url_root, path)
return
filepath = Path(path)
@asyncio.coroutine
def serve_file(request):
"""Serve file from disk."""
res = yield from GZIP_FILE_SENDER.send(request, filepath)
return res
# aiohttp supports regex matching for variables. Using that as temp
# to work around cache busting MD5.
# Turns something like /static/dev-panel.html into
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
base, ext = url_root.rsplit('.', 1)
base, file = base.rsplit('/', 1)
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
url_pattern = "{}/{{filename:{}}}".format(base, regex)
self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine
def start(self):
"""Start the wsgi server."""
if self.cors is not None:
for route in list(self.app.router.routes()):
self.cors.add(route)
if self.ssl_certificate:
context = ssl.SSLContext(SSL_VERSION)
context.options |= SSL_OPTS
context.set_ciphers(CIPHERS)
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
else:
context = None
self._handler = self.app.make_handler()
self.server = yield from self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context)
@asyncio.coroutine
def stop(self):
"""Stop the wsgi server."""
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
yield from self._handler.finish_connections(60.0)
yield from self.app.cleanup()
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
return web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
def json_message(self, error, status_code=200):
"""Return a JSON message response."""
return self.json({'message': error}, status_code)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
response = yield from GZIP_FILE_SENDER.send(request, Path(fil))
return response
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Factory to wrap our handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
remote_addr = get_real_ip(request)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
yield from process_wrong_login(request)
_LOGGER.warning('Login attempt or request with an invalid '
'password from %s', remote_addr)
persistent_notification.async_create(
request.app['hass'],
'Invalid password used from {}'.format(remote_addr),
'Login attempt failed', NOTIFICATION_ID_LOGIN)
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, remote_addr, authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle

View File

@ -0,0 +1,61 @@
"""Authentication for HTTP component."""
import asyncio
import hmac
import logging
from homeassistant.const import HTTP_HEADER_HA_AUTH
from .util import get_real_ip
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def auth_middleware(app, handler):
"""Authentication middleware."""
# If no password set, just always set authenticated=True
if app['hass'].http.api_password is None:
@asyncio.coroutine
def no_auth_middleware_handler(request):
"""Auth middleware to approve all requests."""
request[KEY_AUTHENTICATED] = True
return handler(request)
return no_auth_middleware_handler
@asyncio.coroutine
def auth_middleware_handler(request):
"""Auth middleware to check authentication."""
hass = app['hass']
# Auth code verbose on purpose
authenticated = False
if hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
hass.http.api_password):
# A valid auth header has been set
authenticated = True
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
hass.http.api_password):
authenticated = True
elif is_trusted_ip(request):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return handler(request)
return auth_middleware_handler
def is_trusted_ip(request):
"""Test if request is from a trusted ip."""
ip_addr = get_real_ip(request)
return ip_addr and any(
ip_addr in trusted_network for trusted_network
in request.app[KEY_TRUSTED_NETWORKS])

View File

@ -0,0 +1,132 @@
"""Ban logic for HTTP component."""
import asyncio
from collections import defaultdict
from datetime import datetime
from ipaddress import ip_address
import logging
from aiohttp.web_exceptions import HTTPForbidden
import voluptuous as vol
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.util.yaml import dump
from .const import (
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
KEY_FAILED_LOGIN_ATTEMPTS)
from .util import get_real_ip
NOTIFICATION_ID_BAN = 'ip-ban'
IP_BANS_FILE = 'ip_bans.yaml'
ATTR_BANNED_AT = "banned_at"
SCHEMA_IP_BAN_ENTRY = vol.Schema({
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
})
_LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def ban_middleware(app, handler):
"""IP Ban middleware."""
if not app[KEY_BANS_ENABLED]:
return handler
if KEY_BANNED_IPS not in app:
hass = app['hass']
app[KEY_BANNED_IPS] = yield from hass.loop.run_in_executor(
None, load_ip_bans_config, hass.config.path(IP_BANS_FILE))
@asyncio.coroutine
def ban_middleware_handler(request):
"""Verify if IP is not banned."""
ip_address_ = get_real_ip(request)
is_banned = any(ip_ban.ip_address == ip_address_
for ip_ban in request.app[KEY_BANNED_IPS])
if is_banned:
raise HTTPForbidden()
return handler(request)
return ban_middleware_handler
@asyncio.coroutine
def process_wrong_login(request):
"""Process a wrong login attempt."""
if (not request.app[KEY_BANS_ENABLED] or
request.app[KEY_LOGIN_THRESHOLD] < 1):
return
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
remote_addr = get_real_ip(request)
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
request.app[KEY_LOGIN_THRESHOLD]):
new_ban = IpBan(remote_addr)
request.app[KEY_BANNED_IPS].append(new_ban)
hass = request.app['hass']
yield from hass.loop.run_in_executor(
None, update_ip_bans_config, hass.config.path(IP_BANS_FILE),
new_ban)
_LOGGER.warning('Banned IP %s for too many login attempts',
remote_addr)
persistent_notification.async_create(
hass,
'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN)
class IpBan(object):
"""Represents banned IP address."""
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
"""Initializing Ip Ban object."""
self.ip_address = ip_address(ip_ban)
self.banned_at = banned_at or datetime.utcnow()
def load_ip_bans_config(path: str):
"""Loading list of banned IPs from config file."""
ip_list = []
try:
list_ = load_yaml_config_file(path)
except FileNotFoundError:
return []
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
for ip_ban, ip_info in list_.items():
try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
ip_list.append(IpBan(ip_ban, ip_info['banned_at']))
except vol.Invalid as err:
_LOGGER.error('Failed to load IP ban %s: %s', ip_info, err)
continue
return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan):
"""Update config file with new banned IP address."""
with open(path, 'a') as out:
ip_ = {str(ip_ban.ip_address): {
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
}}
out.write('\n')
out.write(dump(ip_))

View File

@ -0,0 +1,12 @@
"""HTTP specific constants."""
KEY_AUTHENTICATED = 'ha_authenticated'
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
KEY_REAL_IP = 'ha_real_ip'
KEY_BANS_ENABLED = 'ha_bans_enabled'
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_treshold'
KEY_DEVELOPMENT = 'ha_development'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'

View File

@ -0,0 +1,93 @@
"""Static file handling for HTTP component."""
import asyncio
import mimetypes
import re
from aiohttp import hdrs
from aiohttp.file_sender import FileSender
from aiohttp.web_urldispatcher import StaticResource
from aiohttp.web_exceptions import HTTPNotModified
from .const import KEY_DEVELOPMENT
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
class GzipFileSender(FileSender):
"""FileSender class capable of sending gzip version if available."""
# pylint: disable=invalid-name
@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
gzip = False
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
gzip_path = filepath.with_name(filepath.name + '.gz')
if gzip_path.is_file():
filepath = gzip_path
gzip = True
st = filepath.stat()
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
raise HTTPNotModified()
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
resp = self._response_factory()
resp.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
resp.last_modified = st.st_mtime
# CACHE HACK
if not request.app[KEY_DEVELOPMENT]:
cache_time = 31 * 86400 # = 1 month
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time)
file_size = st.st_size
resp.content_length = file_size
with filepath.open('rb') as f:
yield from self._sendfile(request, resp, f, file_size)
return resp
GZIP_FILE_SENDER = GzipFileSender()
@asyncio.coroutine
def staticresource_middleware(app, handler):
"""Enhance StaticResourceHandler middleware.
Adds gzip encoding and fingerprinting matching.
"""
inst = getattr(handler, '__self__', None)
if not isinstance(inst, StaticResource):
return handler
# pylint: disable=protected-access
inst._file_sender = GZIP_FILE_SENDER
@asyncio.coroutine
def static_middleware_handler(request):
"""Strip out fingerprints from resource names."""
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
if fingerprinted:
request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups())
resp = yield from handler(request)
return resp
return static_middleware_handler

View File

@ -0,0 +1,25 @@
"""HTTP utilities."""
from ipaddress import ip_address
from .const import (
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
def get_real_ip(request):
"""Get IP address of client."""
if KEY_REAL_IP in request:
return request[KEY_REAL_IP]
if (request.app[KEY_USE_X_FORWARDED_FOR] and
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
else:
peername = request.transport.get_extra_info('peername')
if peername:
request[KEY_REAL_IP] = ip_address(peername[0])
else:
request[KEY_REAL_IP] = None
return request[KEY_REAL_IP]

View File

@ -250,11 +250,10 @@ def setup(hass, config):
discovery.load_platform(hass, "sensor", DOMAIN, {}, config)
hass.http.register_view(iOSIdentifyDeviceView(hass))
hass.http.register_view(iOSIdentifyDeviceView)
app_config = config.get(DOMAIN, {})
hass.http.register_view(iOSPushConfigView(hass,
app_config.get(CONF_PUSH, {})))
hass.http.register_view(iOSPushConfigView(app_config.get(CONF_PUSH, {})))
return True
@ -266,9 +265,8 @@ class iOSPushConfigView(HomeAssistantView):
url = "/api/ios/push"
name = "api:ios:push"
def __init__(self, hass, push_config):
def __init__(self, push_config):
"""Init the view."""
super().__init__(hass)
self.push_config = push_config
@callback
@ -283,10 +281,6 @@ class iOSIdentifyDeviceView(HomeAssistantView):
url = "/api/ios/identify"
name = "api:ios:identify"
def __init__(self, hass):
"""Init the view."""
super().__init__(hass)
@asyncio.coroutine
def post(self, request):
"""Handle the POST request for device identification."""

View File

@ -101,7 +101,7 @@ def setup(hass, config):
message = message.async_render()
async_log_entry(hass, name, message, domain, entity_id)
hass.http.register_view(LogbookView(hass, config))
hass.http.register_view(LogbookView(config))
register_built_in_panel(hass, 'logbook', 'Logbook',
'mdi:format-list-bulleted-type')
@ -118,9 +118,8 @@ class LogbookView(HomeAssistantView):
name = 'api:logbook'
extra_urls = ['/api/logbook/{datetime}']
def __init__(self, hass, config):
def __init__(self, config):
"""Initilalize the logbook view."""
super().__init__(hass)
self.config = config
@asyncio.coroutine
@ -146,7 +145,8 @@ class LogbookView(HomeAssistantView):
events = recorder.execute(query)
return _exclude_events(events, self.config)
events = yield from self.hass.loop.run_in_executor(None, get_results)
events = yield from request.app['hass'].loop.run_in_executor(
None, get_results)
return self.json(humanify(events))

View File

@ -17,7 +17,7 @@ from homeassistant.config import load_yaml_config_file
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
from homeassistant.const import (
@ -304,7 +304,7 @@ def setup(hass, config):
component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(MediaPlayerImageView(hass, component.entities))
hass.http.register_view(MediaPlayerImageView(component.entities))
component.setup(config)
@ -736,9 +736,8 @@ class MediaPlayerImageView(HomeAssistantView):
url = "/api/media_player_proxy/{entity_id}"
name = "api:media_player:image"
def __init__(self, hass, entities):
def __init__(self, entities):
"""Initialize a media player view."""
super().__init__(hass)
self.entities = entities
@asyncio.coroutine
@ -748,14 +747,14 @@ class MediaPlayerImageView(HomeAssistantView):
if player is None:
return web.Response(status=404)
authenticated = (request.authenticated or
authenticated = (request[KEY_AUTHENTICATED] or
request.GET.get('token') == player.access_token)
if not authenticated:
return web.Response(status=401)
data, content_type = yield from _async_fetch_image(
self.hass, player.media_image_url)
request.app['hass'], player.media_image_url)
if data is None:
return web.Response(status=500)

View File

@ -107,8 +107,8 @@ def get_service(hass, config):
return None
hass.http.register_view(
HTML5PushRegistrationView(hass, registrations, json_path))
hass.http.register_view(HTML5PushCallbackView(hass, registrations))
HTML5PushRegistrationView(registrations, json_path))
hass.http.register_view(HTML5PushCallbackView(registrations))
gcm_api_key = config.get(ATTR_GCM_API_KEY)
gcm_sender_id = config.get(ATTR_GCM_SENDER_ID)
@ -168,9 +168,8 @@ class HTML5PushRegistrationView(HomeAssistantView):
url = '/api/notify.html5'
name = 'api:notify.html5'
def __init__(self, hass, registrations, json_path):
def __init__(self, registrations, json_path):
"""Init HTML5PushRegistrationView."""
super().__init__(hass)
self.registrations = registrations
self.json_path = json_path
@ -237,9 +236,8 @@ class HTML5PushCallbackView(HomeAssistantView):
url = '/api/notify.html5/callback'
name = 'api:notify.html5/callback'
def __init__(self, hass, registrations):
def __init__(self, registrations):
"""Init HTML5PushCallbackView."""
super().__init__(hass)
self.registrations = registrations
def decode_jwt(self, token):
@ -324,7 +322,7 @@ class HTML5PushCallbackView(HomeAssistantView):
event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT,
event_payload[ATTR_TYPE])
self.hass.bus.fire(event_name, event_payload)
request.app['hass'].bus.fire(event_name, event_payload)
return self.json({'status': 'ok',
'event': event_payload[ATTR_TYPE]})

View File

@ -274,7 +274,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url)
hass.http.register_view(FitbitAuthCallbackView(
hass, config, add_devices, oauth))
config, add_devices, oauth))
request_oauth_completion(hass)
@ -286,9 +286,8 @@ class FitbitAuthCallbackView(HomeAssistantView):
url = '/auth/fitbit/callback'
name = 'auth:fitbit:callback'
def __init__(self, hass, config, add_devices, oauth):
def __init__(self, config, add_devices, oauth):
"""Initialize the OAuth callback view."""
super().__init__(hass)
self.config = config
self.add_devices = add_devices
self.oauth = oauth
@ -299,6 +298,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
from oauthlib.oauth2.rfc6749.errors import MismatchingStateError
from oauthlib.oauth2.rfc6749.errors import MissingTokenError
hass = request.app['hass']
data = request.GET
response_message = """Fitbit has been successfully authorized!
@ -306,7 +306,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
if data.get('code') is not None:
redirect_uri = '{}{}'.format(
self.hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH)
hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH)
try:
self.oauth.fetch_access_token(data.get('code'), redirect_uri)
@ -336,12 +336,11 @@ class FitbitAuthCallbackView(HomeAssistantView):
ATTR_CLIENT_ID: self.oauth.client_id,
ATTR_CLIENT_SECRET: self.oauth.client_secret
}
if not config_from_file(self.hass.config.path(FITBIT_CONFIG_FILE),
if not config_from_file(hass.config.path(FITBIT_CONFIG_FILE),
config_contents):
_LOGGER.error("Failed to save config file")
self.hass.async_add_job(setup_platform, self.hass, self.config,
self.add_devices)
hass.async_add_job(setup_platform, hass, self.config, self.add_devices)
return html_response

View File

@ -59,7 +59,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
sensors = {}
hass.http.register_view(TorqueReceiveDataView(
hass, email, vehicle, sensors, add_devices))
email, vehicle, sensors, add_devices))
return True
@ -69,9 +69,8 @@ class TorqueReceiveDataView(HomeAssistantView):
url = API_PATH
name = 'api:torque'
def __init__(self, hass, email, vehicle, sensors, add_devices):
def __init__(self, email, vehicle, sensors, add_devices):
"""Initialize a Torque view."""
super().__init__(hass)
self.email = email
self.vehicle = vehicle
self.sensors = sensors
@ -80,6 +79,7 @@ class TorqueReceiveDataView(HomeAssistantView):
@callback
def get(self, request):
"""Handle Torque data request."""
hass = request.app['hass']
data = request.GET
if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]:
@ -108,7 +108,7 @@ class TorqueReceiveDataView(HomeAssistantView):
self.sensors[pid] = TorqueSensor(
ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]),
units.get(pid, None))
self.hass.async_add_job(self.add_devices, [self.sensors[pid]])
hass.async_add_job(self.add_devices, [self.sensors[pid]])
return None

View File

@ -97,6 +97,7 @@ class NetioApiView(HomeAssistantView):
@callback
def get(self, request, host):
"""Request handler."""
hass = request.app['hass']
data = request.GET
states, consumptions, cumulated_consumptions, start_dates = \
[], [], [], []
@ -119,7 +120,7 @@ class NetioApiView(HomeAssistantView):
ndev.start_dates = start_dates
for dev in DEVICES[host].entities:
self.hass.async_add_job(dev.async_update_ha_state())
hass.async_add_job(dev.async_update_ha_state())
return self.json(True)

View File

@ -360,7 +360,6 @@ HTTP_HEADER_CONTENT_LENGTH = 'Content-Length'
HTTP_HEADER_CACHE_CONTROL = 'Cache-Control'
HTTP_HEADER_EXPIRES = 'Expires'
HTTP_HEADER_ORIGIN = 'Origin'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With'
HTTP_HEADER_ACCEPT = 'Accept'
HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'

View File

@ -0,0 +1,17 @@
"""Logging utilities."""
import logging
class HideSensitiveDataFilter(logging.Filter):
"""Filter API password calls."""
def __init__(self, text):
"""Initialize sensitive data filter."""
super().__init__()
self.text = text
def filter(self, record):
"""Hide sensitive data in messages."""
record.msg = record.msg.replace(self.text, '*******')
return True

View File

@ -10,6 +10,8 @@ import logging
import threading
from contextlib import contextmanager
from aiohttp import web
from homeassistant import core as ha, loader
from homeassistant.bootstrap import (
setup_component, async_prepare_setup_component)
@ -22,6 +24,9 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT)
from homeassistant.components import sun, mqtt
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED)
_TEST_INSTANCE_PORT = SERVER_PORT
_LOGGER = logging.getLogger(__name__)
@ -210,13 +215,23 @@ def mock_http_component(hass):
"""Store registered view."""
if isinstance(view, type):
# Instantiate the view, if needed
view = view(hass)
view = view()
hass.http.views[view.name] = view
hass.http.register_view = mock_register_view
def mock_http_component_app(hass):
"""Create an aiohttp.web.Application instance for testing."""
hass.http.api_password = None
app = web.Application(middlewares=[auth_middleware], loop=hass.loop)
app['hass'] = hass
app[KEY_USE_X_FORWARDED_FOR] = False
app[KEY_BANS_ENABLED] = False
return app
def mock_mqtt_component(hass):
"""Mock the MQTT component."""
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt:

View File

@ -27,8 +27,8 @@ def test_fetching_url(aioclient_mock, hass, test_client):
resp = yield from client.get('/api/camera_proxy/camera.config_test')
assert aioclient_mock.call_count == 1
assert resp.status == 200
assert aioclient_mock.call_count == 1
body = yield from resp.text()
assert body == 'hello world'

View File

@ -0,0 +1 @@
"""Tests for the HTTP component."""

View File

@ -0,0 +1,169 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import logging
from ipaddress import ip_address, ip_network
from unittest.mock import patch
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
'FD01:DB8::1']
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.app[KEY_TRUSTED_NETWORKS] = [
ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_denied_without_password(self):
"""Test access without password."""
req = requests.get(_url(const.URL_API))
assert req.status_code == 401
def test_access_denied_with_wrong_password_in_header(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
assert req.status_code == 401
def test_access_denied_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_denied_with_untrusted_ip(self, caplog):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'util.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_with_password_in_header(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_denied_with_wrong_password_in_url(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API), params={'api_password': 'wrongpassword'})
assert req.status_code == 401
def test_access_with_password_in_url(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API), params={'api_password': API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_granted_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
for remote_addr in TRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 200, \
"{} should be trusted".format(remote_addr)
def test_access_granted_with_trusted_ip(self, caplog):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'auth.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 200, \
'{} should be trusted'.format(remote_addr)

View File

@ -0,0 +1,118 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
from ipaddress import ip_address
from unittest.mock import patch, mock_open
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
in BANNED_IPS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_from_banned_ip(self):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.app[KEY_BANS_ENABLED] = True
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API))
assert req.status_code == 403
def test_access_from_banned_ip_when_ban_is_off(self):
"""Test accessing to server from banned IP when feature is off"""
hass.http.app[KEY_BANS_ENABLED] = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
def test_ip_bans_file_creation(self):
"""Testing if banned IP file created"""
hass.http.app[KEY_BANS_ENABLED] = True
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
m = mock_open()
def call_server():
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address("200.201.202.204")):
print("GETTING API")
return requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
with patch('homeassistant.components.http.ban.open', m, create=True):
req = call_server()
assert req.status_code == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0
req = call_server()
assert req.status_code == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
req = call_server()
assert req.status_code == 403
assert m.call_count == 1

View File

@ -0,0 +1,111 @@
"""The tests for the Home Assistant HTTP component."""
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()

View File

@ -3,10 +3,10 @@ import asyncio
import json
from unittest.mock import patch, MagicMock, mock_open
from aiohttp import web
from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
SUBSCRIPTION_1 = {
'browser': 'chrome',
'subscription': {
@ -121,7 +121,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -153,7 +154,8 @@ class TestHtml5Notify(object):
view = hass.mock_calls[1][1][0]
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -208,7 +210,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -253,7 +256,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -296,7 +300,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -331,7 +336,8 @@ class TestHtml5Notify(object):
view = hass.mock_calls[2][1][0]
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@ -387,7 +393,8 @@ class TestHtml5Notify(object):
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
app = web.Application(loop=loop)
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False

View File

@ -6,7 +6,7 @@ import unittest
import requests
import homeassistant.bootstrap as bootstrap
from homeassistant.components import frontend, http
from homeassistant.components import http
from homeassistant.const import HTTP_HEADER_HA_AUTH
from tests.common import get_test_instance_port, get_test_home_assistant
@ -45,7 +45,6 @@ def setUpModule():
def tearDownModule():
"""Stop everything that was started."""
hass.stop()
frontend.PANELS = {}
class TestFrontend(unittest.TestCase):

View File

@ -1,285 +0,0 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import logging
from ipaddress import ip_network
from unittest.mock import patch, mock_open
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
'FD01:DB8::1']
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
BANNED_IPS = ['200.201.202.203', '100.64.0.1']
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
hass.bus.listen('test_event', lambda _: _)
hass.states.set('test.test', 'a_state')
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.trusted_networks = [
ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS]
hass.http.ip_bans = [http.IpBan(banned_ip)
for banned_ip in BANNED_IPS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_denied_without_password(self):
"""Test access without password."""
req = requests.get(_url(const.URL_API))
assert req.status_code == 401
def test_access_denied_with_wrong_password_in_header(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
assert req.status_code == 401
def test_access_denied_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_denied_with_untrusted_ip(self, caplog):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_with_password_in_header(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
# assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_denied_with_wrong_password_in_url(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API), params={'api_password': 'wrongpassword'})
assert req.status_code == 401
def test_access_with_password_in_url(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API), params={'api_password': API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
# assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_granted_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in TRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 200, \
"{} should be trusted".format(remote_addr)
def test_access_granted_with_trusted_ip(self, caplog):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 200, \
'{} should be trusted'.format(remote_addr)
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()
def test_access_from_banned_ip(self):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.is_ban_enabled = True
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API))
assert req.status_code == 403
def test_access_from_banned_ip_when_ban_is_off(self):
"""Test accessing to server from banned IP when feature is off"""
hass.http.is_ban_enabled = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
def test_ip_bans_file_creation(self):
"""Testing if banned IP file created"""
hass.http.is_ban_enabled = True
hass.http.login_threshold = 1
m = mock_open()
def call_server():
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value="200.201.202.204"):
return requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
with patch('homeassistant.components.http.open', m, create=True):
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS)
assert m.call_count == 0
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a')
req = call_server()
assert req.status_code == 403
assert m.call_count == 1

View File

@ -165,7 +165,15 @@ class TestCheckConfig(unittest.TestCase):
self.assertDictEqual({
'components': {'http': {'api_password': 'abc123',
'cors_allowed_origins': [],
'development': '0',
'ip_ban_enabled': True,
'login_attempts_threshold': -1,
'server_host': '0.0.0.0',
'server_port': 8123,
'ssl_certificate': None,
'ssl_key': None,
'trusted_networks': [],
'use_x_forwarded_for': False}},
'except': {},
'secret_cache': {secrets_path: {'http_pw': 'abc123'}},