From 90de8fff5e45a857594d3e7fd8d91a90f73917b2 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 00:56:35 +0000 Subject: [PATCH 1/9] Add Google Gen AI Conversation Agent Entity --- .../__init__.py | 134 +------ .../conversation.py | 162 ++++++++ .../manifest.json | 1 + .../conftest.py | 1 + .../snapshots/test_conversation.ambr | 169 +++++++++ .../test_conversation.py | 352 ++++++++++++++++++ .../test_init.py | 148 +------- 7 files changed, 696 insertions(+), 271 deletions(-) create mode 100644 homeassistant/components/google_generative_ai_conversation/conversation.py create mode 100644 tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr create mode 100644 tests/components/google_generative_ai_conversation/test_conversation.py diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index e956c288b53..a828c316544 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -6,52 +6,32 @@ from functools import partial import logging import mimetypes from pathlib import Path -from typing import Literal from google.api_core.exceptions import ClientError import google.generativeai as genai import google.generativeai.types as genai_types import voluptuous as vol -from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_API_KEY, MATCH_ALL +from homeassistant.const import CONF_API_KEY, Platform from homeassistant.core import ( HomeAssistant, ServiceCall, ServiceResponse, SupportsResponse, ) -from homeassistant.exceptions import ( - ConfigEntryNotReady, - HomeAssistantError, - TemplateError, -) -from homeassistant.helpers import config_validation as cv, intent, template +from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError +from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from homeassistant.util import ulid -from .const import ( - CONF_CHAT_MODEL, - CONF_MAX_TOKENS, - CONF_PROMPT, - CONF_TEMPERATURE, - CONF_TOP_K, - CONF_TOP_P, - DEFAULT_CHAT_MODEL, - DEFAULT_MAX_TOKENS, - DEFAULT_PROMPT, - DEFAULT_TEMPERATURE, - DEFAULT_TOP_K, - DEFAULT_TOP_P, - DOMAIN, -) +from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN _LOGGER = logging.getLogger(__name__) SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) +PLATFORMS = (Platform.CONVERSATION,) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -130,109 +110,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return False raise ConfigEntryNotReady(err) from err - conversation.async_set_agent(hass, entry, GoogleGenerativeAIAgent(hass, entry)) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload GoogleGenerativeAI.""" + if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + return False + genai.configure(api_key=None) - conversation.async_unset_agent(hass, entry) return True - - -class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent): - """Google Generative AI conversation agent.""" - - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: - """Initialize the agent.""" - self.hass = hass - self.entry = entry - self.history: dict[str, list[genai_types.ContentType]] = {} - - @property - def supported_languages(self) -> list[str] | Literal["*"]: - """Return a list of supported languages.""" - return MATCH_ALL - - async def async_process( - self, user_input: conversation.ConversationInput - ) -> conversation.ConversationResult: - """Process a sentence.""" - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - model = genai.GenerativeModel( - model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL), - generation_config={ - "temperature": self.entry.options.get( - CONF_TEMPERATURE, DEFAULT_TEMPERATURE - ), - "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P), - "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K), - "max_output_tokens": self.entry.options.get( - CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS - ), - }, - ) - _LOGGER.debug("Model: %s", model) - - if user_input.conversation_id in self.history: - conversation_id = user_input.conversation_id - messages = self.history[conversation_id] - else: - conversation_id = ulid.ulid_now() - messages = [{}, {}] - - try: - prompt = self._async_generate_prompt(raw_prompt) - except TemplateError as err: - _LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - messages[0] = {"role": "user", "parts": prompt} - messages[1] = {"role": "model", "parts": "Ok"} - - _LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) - - chat = model.start_chat(history=messages) - try: - chat_response = await chat.send_message_async(user_input.text) - except ( - ClientError, - ValueError, - genai_types.BlockedPromptException, - genai_types.StopCandidateException, - ) as err: - _LOGGER.error("Error sending message: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to Google Generative AI: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - _LOGGER.debug("Response: %s", chat_response.parts) - self.history[conversation_id] = chat.history - - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(chat_response.text) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - def _async_generate_prompt(self, raw_prompt: str) -> str: - """Generate a prompt for the user.""" - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py new file mode 100644 index 00000000000..673bd152613 --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -0,0 +1,162 @@ +"""Conversation support for the Google Generative AI Conversation integration.""" + +from __future__ import annotations + +import logging +from typing import Literal + +from google.api_core.exceptions import ClientError +import google.generativeai as genai +import google.generativeai.types as genai_types + +from homeassistant.components import assist_pipeline, conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import TemplateError +from homeassistant.helpers import intent, template +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.util import ulid + +from .const import ( + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_K, + CONF_TOP_P, + DEFAULT_CHAT_MODEL, + DEFAULT_MAX_TOKENS, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_K, + DEFAULT_TOP_P, +) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up conversation entities.""" + agent = GoogleGenerativeAIAgent(hass, config_entry) + async_add_entities([agent]) + + +class GoogleGenerativeAIAgent( + conversation.ConversationEntity, conversation.AbstractConversationAgent +): + """Google Generative AI conversation agent.""" + + _attr_has_entity_name = True + + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + """Initialize the agent.""" + self.hass = hass + self.entry = entry + self.history: dict[str, list[genai_types.ContentType]] = {} + self._attr_name = entry.title + self._attr_unique_id = entry.entry_id + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + assist_pipeline.async_migrate_engine( + self.hass, "conversation", self.entry.entry_id, self.entity_id + ) + conversation.async_set_agent(self.hass, self.entry, self) + + async def async_will_remove_from_hass(self) -> None: + """When entity will be removed from Home Assistant.""" + conversation.async_unset_agent(self.hass, self.entry) + await super().async_will_remove_from_hass() + + @property + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + return MATCH_ALL + + async def async_process( + self, user_input: conversation.ConversationInput + ) -> conversation.ConversationResult: + """Process a sentence.""" + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + model = genai.GenerativeModel( + model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL), + generation_config={ + "temperature": self.entry.options.get( + CONF_TEMPERATURE, DEFAULT_TEMPERATURE + ), + "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P), + "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K), + "max_output_tokens": self.entry.options.get( + CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS + ), + }, + ) + _LOGGER.debug("Model: %s", model) + + if user_input.conversation_id in self.history: + conversation_id = user_input.conversation_id + messages = self.history[conversation_id] + else: + conversation_id = ulid.ulid_now() + messages = [{}, {}] + + try: + prompt = self._async_generate_prompt(raw_prompt) + except TemplateError as err: + _LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem with my template: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + messages[0] = {"role": "user", "parts": prompt} + messages[1] = {"role": "model", "parts": "Ok"} + + _LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) + + chat = model.start_chat(history=messages) + try: + chat_response = await chat.send_message_async(user_input.text) + except ( + ClientError, + ValueError, + genai_types.BlockedPromptException, + genai_types.StopCandidateException, + ) as err: + _LOGGER.error("Error sending message: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to Google Generative AI: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + _LOGGER.debug("Response: %s", chat_response.parts) + self.history[conversation_id] = chat.history + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(chat_response.text) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + def _async_generate_prompt(self, raw_prompt: str) -> str: + """Generate a prompt for the user.""" + return template.Template(raw_prompt, self.hass).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ) diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json index 5bafa9c43de..5f233936637 100644 --- a/homeassistant/components/google_generative_ai_conversation/manifest.json +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -1,6 +1,7 @@ { "domain": "google_generative_ai_conversation", "name": "Google Generative AI Conversation", + "after_dependencies": ["assist_pipeline"], "codeowners": ["@tronikos"], "config_flow": true, "dependencies": ["conversation"], diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index c377a469df0..d5b4e8672e3 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -16,6 +16,7 @@ def mock_config_entry(hass): """Mock a config entry.""" entry = MockConfigEntry( domain="google_generative_ai_conversation", + title="Google Generative AI Conversation", data={ "api_key": "bla", }, diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr new file mode 100644 index 00000000000..bf37fe0f2d9 --- /dev/null +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -0,0 +1,169 @@ +# serializer version: 1 +# name: test_default_prompt[None] + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 0.9, + 'top_k': 1, + 'top_p': 1.0, + }), + 'model_name': 'models/gemini-pro', + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'hello', + ), + dict({ + }), + ), + ]) +# --- +# name: test_default_prompt[conversation.google_generative_ai_conversation] + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 0.9, + 'top_k': 1, + 'top_p': 1.0, + }), + 'model_name': 'models/gemini-pro', + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'hello', + ), + dict({ + }), + ), + ]) +# --- +# name: test_generate_content_service_with_image + list([ + tuple( + '', + tuple( + ), + dict({ + 'model_name': 'gemini-pro-vision', + }), + ), + tuple( + '().generate_content_async', + tuple( + list([ + 'Describe this image from my doorbell camera', + dict({ + 'data': b'image bytes', + 'mime_type': 'image/jpeg', + }), + ]), + ), + dict({ + }), + ), + ]) +# --- +# name: test_generate_content_service_without_images + list([ + tuple( + '', + tuple( + ), + dict({ + 'model_name': 'gemini-pro', + }), + ), + tuple( + '().generate_content_async', + tuple( + list([ + 'Write an opening speech for a Home Assistant release party', + ]), + ), + dict({ + }), + ), + ]) +# --- diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py new file mode 100644 index 00000000000..719aab725fb --- /dev/null +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -0,0 +1,352 @@ +"""Tests for the Google Generative AI Conversation integration conversation platform.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +from google.api_core.exceptions import ClientError +import pytest +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components import conversation +from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import area_registry as ar, device_registry as dr, intent + +from tests.common import MockConfigEntry + + +@pytest.mark.parametrize( + "agent_id", [None, "conversation.google_generative_ai_conversation"] +) +async def test_default_prompt( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, + agent_id: str, +) -> None: + """Test that the default prompt works.""" + entry = MockConfigEntry(title=None) + entry.add_to_hass(hass) + for i in range(3): + area_registry.async_create(f"{i}Empty Area") + + if agent_id is None: + agent_id = mock_config_entry.entry_id + + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "1234")}, + name="Test Device", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + ) + for i in range(3): + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", f"{i}abcd")}, + name="Test Service", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + entry_type=dr.DeviceEntryType.SERVICE, + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "5678")}, + name="Test Device 2", + manufacturer="Test Manufacturer 2", + model="Device 2", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "qwer")}, + name="Test Device 4", + suggested_area="Test Area 2", + ) + device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-disabled")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_update_device( + device.id, disabled_by=dr.DeviceEntryDisabler.USER + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-no-name")}, + manufacturer="Test Manufacturer NoName", + model="Test Model NoName", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-integer-values")}, + name=1, + manufacturer=2, + model=3, + suggested_area="Test Area 2", + ) + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_model.return_value.start_chat.return_value = AsyncMock() + result = await conversation.async_converse( + hass, + "hello", + None, + Context(), + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that the default prompt works.""" + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + mock_chat.send_message_async.side_effect = ClientError("") + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with ( + patch( + "google.generativeai.get_model", + ), + patch("google.generativeai.GenerativeModel"), + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test GoogleGenerativeAIAgent.""" + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == "*" + + +async def test_generate_content_service_without_images( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test generate content service.""" + stubbed_generated_content = ( + "I'm thrilled to welcome you all to the release " + "party for the latest version of Home Assistant!" + ) + + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_response = MagicMock() + mock_response.text = stubbed_generated_content + mock_model.return_value.generate_content_async = AsyncMock( + return_value=mock_response + ) + response = await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + {"prompt": "Write an opening speech for a Home Assistant release party"}, + blocking=True, + return_response=True, + ) + + assert response == { + "text": stubbed_generated_content, + } + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + +async def test_generate_content_service_with_image( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test generate content service.""" + stubbed_generated_content = ( + "A mail carrier is at your front door delivering a package" + ) + + with ( + patch("google.generativeai.GenerativeModel") as mock_model, + patch( + "homeassistant.components.google_generative_ai_conversation.Path.read_bytes", + return_value=b"image bytes", + ), + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + ): + mock_response = MagicMock() + mock_response.text = stubbed_generated_content + mock_model.return_value.generate_content_async = AsyncMock( + return_value=mock_response + ) + response = await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "image_filename": "doorbell_snapshot.jpg", + }, + blocking=True, + return_response=True, + ) + + assert response == { + "text": stubbed_generated_content, + } + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_content_service_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test generate content service handles errors.""" + with ( + patch("google.generativeai.GenerativeModel") as mock_model, + pytest.raises( + HomeAssistantError, match="Error generating content: None reason" + ), + ): + mock_model.return_value.generate_content_async = AsyncMock( + side_effect=ClientError("reason") + ) + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + {"prompt": "write a story about an epic fail"}, + blocking=True, + return_response=True, + ) + + +async def test_generate_content_service_with_image_not_allowed_path( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test generate content service with an image in a not allowed path.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=False), + pytest.raises( + HomeAssistantError, + match=( + "Cannot read `doorbell_snapshot.jpg`, no access to path; " + "`allowlist_external_dirs` may need to be adjusted in " + "`configuration.yaml`" + ), + ), + ): + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "image_filename": "doorbell_snapshot.jpg", + }, + blocking=True, + return_response=True, + ) + + +async def test_generate_content_service_with_image_not_exists( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test generate content service with an image that does not exist.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + patch("pathlib.Path.exists", return_value=False), + pytest.raises( + HomeAssistantError, match="`doorbell_snapshot.jpg` does not exist" + ), + ): + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "image_filename": "doorbell_snapshot.jpg", + }, + blocking=True, + return_response=True, + ) + + +async def test_generate_content_service_with_non_image( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test generate content service with a non image.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + patch("pathlib.Path.exists", return_value=True), + pytest.raises( + HomeAssistantError, match="`doorbell_snapshot.mp4` is not an image" + ), + ): + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + { + "prompt": "Describe this image from my doorbell camera", + "image_filename": "doorbell_snapshot.mp4", + }, + blocking=True, + return_response=True, + ) diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 07254be9e3f..daae8582594 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -6,158 +6,12 @@ from google.api_core.exceptions import ClientError import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components import conversation -from homeassistant.core import Context, HomeAssistant +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from tests.common import MockConfigEntry -async def test_default_prompt( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - snapshot: SnapshotAssertion, -) -> None: - """Test that the default prompt works.""" - entry = MockConfigEntry(title=None) - entry.add_to_hass(hass) - for i in range(3): - area_registry.async_create(f"{i}Empty Area") - - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - ) - for i in range(3): - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", f"{i}abcd")}, - name="Test Service", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - entry_type=dr.DeviceEntryType.SERVICE, - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "5678")}, - name="Test Device 2", - manufacturer="Test Manufacturer 2", - model="Device 2", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "qwer")}, - name="Test Device 4", - suggested_area="Test Area 2", - ) - device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-disabled")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_update_device( - device.id, disabled_by=dr.DeviceEntryDisabler.USER - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-no-name")}, - manufacturer="Test Manufacturer NoName", - model="Test Model NoName", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-integer-values")}, - name=1, - manufacturer=2, - model=3, - suggested_area="Test Area 2", - ) - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_model.return_value.start_chat.return_value = AsyncMock() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot - - -async def test_error_handling( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that the default prompt works.""" - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = ClientError("") - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_template_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template error handling works.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", - }, - ) - with ( - patch( - "google.generativeai.get_model", - ), - patch("google.generativeai.GenerativeModel"), - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_conversation_agent( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, -) -> None: - """Test GoogleGenerativeAIAgent.""" - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert agent.supported_languages == "*" - - async def test_generate_content_service_without_images( hass: HomeAssistant, mock_config_entry: MockConfigEntry, From b35a6d68ab26bbe7a523aefc6e485f6f601e4d29 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 01:20:20 +0000 Subject: [PATCH 2/9] Rename agent to entity --- .../google_generative_ai_conversation/conversation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 673bd152613..778b5831144 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -42,11 +42,11 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = GoogleGenerativeAIAgent(hass, config_entry) + agent = GoogleGenerativeAIConversationEntity(hass, config_entry) async_add_entities([agent]) -class GoogleGenerativeAIAgent( +class GoogleGenerativeAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): """Google Generative AI conversation agent.""" From 5d63d3bf830f83a340ed198f9f703f074d18e19b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 01:22:59 +0000 Subject: [PATCH 3/9] Add assist_pipeline after dependencies. --- homeassistant/components/ollama/manifest.json | 1 + 1 file changed, 1 insertion(+) diff --git a/homeassistant/components/ollama/manifest.json b/homeassistant/components/ollama/manifest.json index 6b16ae667f1..7afaaa3dbd4 100644 --- a/homeassistant/components/ollama/manifest.json +++ b/homeassistant/components/ollama/manifest.json @@ -1,6 +1,7 @@ { "domain": "ollama", "name": "Ollama", + "after_dependencies": ["assist_pipeline"], "codeowners": ["@synesthesiam"], "config_flow": true, "dependencies": ["conversation"], From 0616aa16e81382d9b014fe9d6b2bc024132e638f Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 01:24:20 +0000 Subject: [PATCH 4/9] Revert ollama changes --- homeassistant/components/ollama/manifest.json | 1 - 1 file changed, 1 deletion(-) diff --git a/homeassistant/components/ollama/manifest.json b/homeassistant/components/ollama/manifest.json index 7afaaa3dbd4..6b16ae667f1 100644 --- a/homeassistant/components/ollama/manifest.json +++ b/homeassistant/components/ollama/manifest.json @@ -1,7 +1,6 @@ { "domain": "ollama", "name": "Ollama", - "after_dependencies": ["assist_pipeline"], "codeowners": ["@synesthesiam"], "config_flow": true, "dependencies": ["conversation"], From 98990b34c5c2af99b60840d01ddf5c132e5282ce Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 14:12:49 +0000 Subject: [PATCH 5/9] Don't copy service tests to conversation_test.py --- .../test_conversation.py | 186 +----------------- 1 file changed, 1 insertion(+), 185 deletions(-) diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 719aab725fb..c20bddc1aad 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -1,6 +1,6 @@ """Tests for the Google Generative AI Conversation integration conversation platform.""" -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from google.api_core.exceptions import ClientError import pytest @@ -8,7 +8,6 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant -from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from tests.common import MockConfigEntry @@ -167,186 +166,3 @@ async def test_conversation_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" - - -async def test_generate_content_service_without_images( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - snapshot: SnapshotAssertion, -) -> None: - """Test generate content service.""" - stubbed_generated_content = ( - "I'm thrilled to welcome you all to the release " - "party for the latest version of Home Assistant!" - ) - - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_response = MagicMock() - mock_response.text = stubbed_generated_content - mock_model.return_value.generate_content_async = AsyncMock( - return_value=mock_response - ) - response = await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - {"prompt": "Write an opening speech for a Home Assistant release party"}, - blocking=True, - return_response=True, - ) - - assert response == { - "text": stubbed_generated_content, - } - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot - - -async def test_generate_content_service_with_image( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - snapshot: SnapshotAssertion, -) -> None: - """Test generate content service.""" - stubbed_generated_content = ( - "A mail carrier is at your front door delivering a package" - ) - - with ( - patch("google.generativeai.GenerativeModel") as mock_model, - patch( - "homeassistant.components.google_generative_ai_conversation.Path.read_bytes", - return_value=b"image bytes", - ), - patch("pathlib.Path.exists", return_value=True), - patch.object(hass.config, "is_allowed_path", return_value=True), - ): - mock_response = MagicMock() - mock_response.text = stubbed_generated_content - mock_model.return_value.generate_content_async = AsyncMock( - return_value=mock_response - ) - response = await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - { - "prompt": "Describe this image from my doorbell camera", - "image_filename": "doorbell_snapshot.jpg", - }, - blocking=True, - return_response=True, - ) - - assert response == { - "text": stubbed_generated_content, - } - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot - - -@pytest.mark.usefixtures("mock_init_component") -async def test_generate_content_service_error( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, -) -> None: - """Test generate content service handles errors.""" - with ( - patch("google.generativeai.GenerativeModel") as mock_model, - pytest.raises( - HomeAssistantError, match="Error generating content: None reason" - ), - ): - mock_model.return_value.generate_content_async = AsyncMock( - side_effect=ClientError("reason") - ) - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - {"prompt": "write a story about an epic fail"}, - blocking=True, - return_response=True, - ) - - -async def test_generate_content_service_with_image_not_allowed_path( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - snapshot: SnapshotAssertion, -) -> None: - """Test generate content service with an image in a not allowed path.""" - with ( - patch("pathlib.Path.exists", return_value=True), - patch.object(hass.config, "is_allowed_path", return_value=False), - pytest.raises( - HomeAssistantError, - match=( - "Cannot read `doorbell_snapshot.jpg`, no access to path; " - "`allowlist_external_dirs` may need to be adjusted in " - "`configuration.yaml`" - ), - ), - ): - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - { - "prompt": "Describe this image from my doorbell camera", - "image_filename": "doorbell_snapshot.jpg", - }, - blocking=True, - return_response=True, - ) - - -async def test_generate_content_service_with_image_not_exists( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - snapshot: SnapshotAssertion, -) -> None: - """Test generate content service with an image that does not exist.""" - with ( - patch("pathlib.Path.exists", return_value=True), - patch.object(hass.config, "is_allowed_path", return_value=True), - patch("pathlib.Path.exists", return_value=False), - pytest.raises( - HomeAssistantError, match="`doorbell_snapshot.jpg` does not exist" - ), - ): - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - { - "prompt": "Describe this image from my doorbell camera", - "image_filename": "doorbell_snapshot.jpg", - }, - blocking=True, - return_response=True, - ) - - -async def test_generate_content_service_with_non_image( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - snapshot: SnapshotAssertion, -) -> None: - """Test generate content service with a non image.""" - with ( - patch("pathlib.Path.exists", return_value=True), - patch.object(hass.config, "is_allowed_path", return_value=True), - patch("pathlib.Path.exists", return_value=True), - pytest.raises( - HomeAssistantError, match="`doorbell_snapshot.mp4` is not an image" - ), - ): - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - { - "prompt": "Describe this image from my doorbell camera", - "image_filename": "doorbell_snapshot.mp4", - }, - blocking=True, - return_response=True, - ) From ee9efa0455a854ed539f4524ae4848a0d0b6a82b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 14:22:04 +0000 Subject: [PATCH 6/9] Move logger and cleanup snapshots --- .../__init__.py | 6 +- .../const.py | 3 + .../conversation.py | 14 ++--- .../snapshots/test_init.ambr | 60 ------------------- 4 files changed, 11 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index a828c316544..d4a6c5bfa69 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from functools import partial -import logging import mimetypes from pathlib import Path @@ -24,9 +23,8 @@ from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN +from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER -_LOGGER = logging.getLogger(__name__) SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" @@ -106,7 +104,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) except ClientError as err: if err.reason == "API_KEY_INVALID": - _LOGGER.error("Invalid API key: %s", err) + LOGGER.error("Invalid API key: %s", err) return False raise ConfigEntryNotReady(err) from err diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 2798b85f308..0619ee45e4b 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -1,6 +1,9 @@ """Constants for the Google Generative AI Conversation integration.""" +import logging + DOMAIN = "google_generative_ai_conversation" +LOGGER = logging.getLogger(__name__) CONF_PROMPT = "prompt" DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 778b5831144..71bb46de890 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -2,7 +2,6 @@ from __future__ import annotations -import logging from typing import Literal from google.api_core.exceptions import ClientError @@ -31,10 +30,9 @@ from .const import ( DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, + LOGGER, ) -_LOGGER = logging.getLogger(__name__) - async def async_setup_entry( hass: HomeAssistant, @@ -97,7 +95,7 @@ class GoogleGenerativeAIConversationEntity( ), }, ) - _LOGGER.debug("Model: %s", model) + LOGGER.debug("Model: %s", model) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id @@ -109,7 +107,7 @@ class GoogleGenerativeAIConversationEntity( try: prompt = self._async_generate_prompt(raw_prompt) except TemplateError as err: - _LOGGER.error("Error rendering prompt: %s", err) + LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, @@ -122,7 +120,7 @@ class GoogleGenerativeAIConversationEntity( messages[0] = {"role": "user", "parts": prompt} messages[1] = {"role": "model", "parts": "Ok"} - _LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) + LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) chat = model.start_chat(history=messages) try: @@ -133,7 +131,7 @@ class GoogleGenerativeAIConversationEntity( genai_types.BlockedPromptException, genai_types.StopCandidateException, ) as err: - _LOGGER.error("Error sending message: %s", err) + LOGGER.error("Error sending message: %s", err) intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, @@ -143,7 +141,7 @@ class GoogleGenerativeAIConversationEntity( response=intent_response, conversation_id=conversation_id ) - _LOGGER.debug("Response: %s", chat_response.parts) + LOGGER.debug("Response: %s", chat_response.parts) self.history[conversation_id] = chat.history intent_response = intent.IntentResponse(language=user_input.language) diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr index 5347c010f28..aba3f35eb19 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -1,64 +1,4 @@ # serializer version: 1 -# name: test_default_prompt - list([ - tuple( - '', - tuple( - ), - dict({ - 'generation_config': dict({ - 'max_output_tokens': 150, - 'temperature': 0.9, - 'top_k': 1, - 'top_p': 1.0, - }), - 'model_name': 'models/gemini-pro', - }), - ), - tuple( - '().start_chat', - tuple( - ), - dict({ - 'history': list([ - dict({ - 'parts': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Answer the user's questions about the world truthfully. - - If the user wants to control a device, reject the request and suggest using the Home Assistant app. - ''', - 'role': 'user', - }), - dict({ - 'parts': 'Ok', - 'role': 'model', - }), - ]), - }), - ), - tuple( - '().start_chat().send_message_async', - tuple( - 'hello', - ), - dict({ - }), - ), - ]) -# --- # name: test_generate_content_service_with_image list([ tuple( From 51bc0db67b268d1cb04f9995bdd09cb1ed104e08 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 14:23:01 +0000 Subject: [PATCH 7/9] Move property after init --- .../google_generative_ai_conversation/conversation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 71bb46de890..d87a0a14f9e 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -59,6 +59,11 @@ class GoogleGenerativeAIConversationEntity( self._attr_name = entry.title self._attr_unique_id = entry.entry_id + @property + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + return MATCH_ALL + async def async_added_to_hass(self) -> None: """When entity is added to Home Assistant.""" await super().async_added_to_hass() @@ -72,11 +77,6 @@ class GoogleGenerativeAIConversationEntity( conversation.async_unset_agent(self.hass, self.entry) await super().async_will_remove_from_hass() - @property - def supported_languages(self) -> list[str] | Literal["*"]: - """Return a list of supported languages.""" - return MATCH_ALL - async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: From 9595161348570d41c5729d4d5df257c4a1eadd8f Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 29 Apr 2024 14:26:58 +0000 Subject: [PATCH 8/9] Set logger to use package --- .../components/google_generative_ai_conversation/const.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 0619ee45e4b..f7e71989efd 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -3,7 +3,7 @@ import logging DOMAIN = "google_generative_ai_conversation" -LOGGER = logging.getLogger(__name__) +LOGGER = logging.getLogger(__package__) CONF_PROMPT = "prompt" DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. From 4420a9d7bb4dda21f9e96dc94dbeabaa1a88f048 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Fri, 3 May 2024 14:47:07 +0000 Subject: [PATCH 9/9] Cleanup hass from constructor --- .../google_generative_ai_conversation/conversation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index d87a0a14f9e..828bcc84c21 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -40,7 +40,7 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = GoogleGenerativeAIConversationEntity(hass, config_entry) + agent = GoogleGenerativeAIConversationEntity(config_entry) async_add_entities([agent]) @@ -51,9 +51,8 @@ class GoogleGenerativeAIConversationEntity( _attr_has_entity_name = True - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" - self.hass = hass self.entry = entry self.history: dict[str, list[genai_types.ContentType]] = {} self._attr_name = entry.title