diff --git a/homeassistant/components/universal/media_player.py b/homeassistant/components/universal/media_player.py index 2a5fcee34dc3..e4891dca68af 100644 --- a/homeassistant/components/universal/media_player.py +++ b/homeassistant/components/universal/media_player.py @@ -1,9 +1,14 @@ """Combination of multiple media players for a universal controller.""" from copy import copy +from typing import Optional import voluptuous as vol -from homeassistant.components.media_player import PLATFORM_SCHEMA, MediaPlayerEntity +from homeassistant.components.media_player import ( + DEVICE_CLASSES_SCHEMA, + PLATFORM_SCHEMA, + MediaPlayerEntity, +) from homeassistant.components.media_player.const import ( ATTR_APP_ID, ATTR_APP_NAME, @@ -56,6 +61,7 @@ from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, + CONF_DEVICE_CLASS, CONF_NAME, CONF_STATE, CONF_STATE_TEMPLATE, @@ -109,6 +115,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( vol.Optional(CONF_ATTRS, default={}): vol.Or( cv.ensure_list(ATTRS_SCHEMA), ATTRS_SCHEMA ), + vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_STATE_TEMPLATE): cv.template, }, extra=vol.REMOVE_EXTRA, @@ -126,6 +133,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= config.get(CONF_CHILDREN), config.get(CONF_COMMANDS), config.get(CONF_ATTRS), + config.get(CONF_DEVICE_CLASS), config.get(CONF_STATE_TEMPLATE), ) @@ -135,7 +143,16 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= class UniversalMediaPlayer(MediaPlayerEntity): """Representation of an universal media player.""" - def __init__(self, hass, name, children, commands, attributes, state_template=None): + def __init__( + self, + hass, + name, + children, + commands, + attributes, + device_class=None, + state_template=None, + ): """Initialize the Universal media device.""" self.hass = hass self._name = name @@ -150,6 +167,7 @@ class UniversalMediaPlayer(MediaPlayerEntity): self._child_state = None self._state_template_result = None self._state_template = state_template + self._device_class = device_class async def async_added_to_hass(self): """Subscribe to children and template state changes.""" @@ -255,6 +273,11 @@ class UniversalMediaPlayer(MediaPlayerEntity): """No polling needed.""" return False + @property + def device_class(self) -> Optional[str]: + """Return the class of this device.""" + return self._device_class + @property def master_state(self): """Return the master state for entity or None.""" diff --git a/tests/components/universal/test_media_player.py b/tests/components/universal/test_media_player.py index 75cf029af40d..8d8bc80234e2 100644 --- a/tests/components/universal/test_media_player.py +++ b/tests/components/universal/test_media_player.py @@ -872,6 +872,25 @@ async def test_state_template(hass): assert hass.states.get("media_player.tv").state == STATE_OFF +async def test_device_class(hass): + """Test device_class property.""" + hass.states.async_set("sensor.test_sensor", "on") + + await async_setup_component( + hass, + "media_player", + { + "media_player": { + "platform": "universal", + "name": "tv", + "device_class": "tv", + } + }, + ) + await hass.async_block_till_done() + assert hass.states.get("media_player.tv").attributes["device_class"] == "tv" + + async def test_invalid_state_template(hass): """Test invalid state template sets state to None.""" hass.states.async_set("sensor.test_sensor", "on") @@ -1001,6 +1020,9 @@ async def test_reload(hass): assert hass.states.get("media_player.tv") is None assert hass.states.get("media_player.master_bed_tv").state == "on" assert hass.states.get("media_player.master_bed_tv").attributes["source"] == "act2" + assert ( + "device_class" not in hass.states.get("media_player.master_bed_tv").attributes + ) def _get_fixtures_base_path():