diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6cb5295f0..7deaad285 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -9,6 +9,7 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.auth import Secret from opensearchpy import OpenSearch from opensearchpy.helpers import bulk @@ -45,7 +46,10 @@ def __init__( mappings: Optional[Dict[str, Any]] = None, settings: Optional[Dict[str, Any]] = DEFAULT_SETTINGS, create_index: bool = True, - http_auth: Any = None, + http_auth: Any = ( + Secret.from_env_var("OPENSEARCH_USERNAME", strict=False), # noqa: B008 + Secret.from_env_var("OPENSEARCH_PASSWORD", strict=False), # noqa: B008 + ), use_ssl: Optional[bool] = None, verify_certs: Optional[bool] = None, timeout: Optional[int] = None, @@ -79,6 +83,7 @@ def __init__( - a tuple of (username, password) - a list of [username, password] - a string of "username:password" + If not provided, will read values from OPENSEARCH_USERNAME and OPENSEARCH_PASSWORD environment variables. For AWS authentication with `Urllib3HttpConnection` pass an instance of `AWSAuth`. Defaults to None :param use_ssl: Whether to use SSL. Defaults to None @@ -97,6 +102,17 @@ def __init__( self._mappings = mappings or self._get_default_mappings() self._settings = settings self._create_index = create_index + self._http_auth_are_secrets = False + + # Handle authentication + if isinstance(http_auth, (tuple, list)) and len(http_auth) == 2: # noqa: PLR2004 + username, password = http_auth + if isinstance(username, Secret) and isinstance(password, Secret): + self._http_auth_are_secrets = True + username_val = username.resolve_value() + password_val = password.resolve_value() + http_auth = [username_val, password_val] if username_val and password_val else None + self._http_auth = http_auth self._use_ssl = use_ssl self._verify_certs = verify_certs @@ -174,15 +190,24 @@ def create_index( self.client.indices.create(index=index, body={"mappings": mappings, "settings": settings}) def to_dict(self) -> Dict[str, Any]: - # This is not the best solution to serialise this class but is the fastest to implement. - # Not all kwargs types can be serialised to text so this can fail. We must serialise each - # type explicitly to handle this properly. """ Serializes the component to a dictionary. :returns: Dictionary with serialized data. """ + # Handle http_auth serialization + if isinstance(self._http_auth, list) and self._http_auth_are_secrets: + # Recreate the Secret objects for serialization + http_auth = [ + Secret.from_env_var("OPENSEARCH_USERNAME", strict=False).to_dict(), + Secret.from_env_var("OPENSEARCH_PASSWORD", strict=False).to_dict(), + ] + elif isinstance(self._http_auth, AWSAuth): + http_auth = self._http_auth.to_dict() + else: + http_auth = self._http_auth + return default_to_dict( self, hosts=self._hosts, @@ -194,7 +219,7 @@ def to_dict(self) -> Dict[str, Any]: settings=self._settings, create_index=self._create_index, return_embedding=self._return_embedding, - http_auth=self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth, + http_auth=http_auth, use_ssl=self._use_ssl, verify_certs=self._verify_certs, timeout=self._timeout, @@ -208,14 +233,16 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchDocumentStore": :param data: Dictionary to deserialize from. - :returns: Deserialized component. """ - if http_auth := data.get("init_parameters", {}).get("http_auth"): + init_params = data.get("init_parameters", {}) + if http_auth := init_params.get("http_auth"): if isinstance(http_auth, dict): - data["init_parameters"]["http_auth"] = AWSAuth.from_dict(http_auth) - + init_params["http_auth"] = AWSAuth.from_dict(http_auth) + elif isinstance(http_auth, (tuple, list)): + are_secrets = all(isinstance(item, dict) and "type" in item for item in http_auth) + init_params["http_auth"] = [Secret.from_dict(item) for item in http_auth] if are_secrets else http_auth return default_from_dict(cls, data) def count_documents(self) -> int: diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 043f59891..82c21e6fe 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -263,6 +263,66 @@ def test_to_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.Mon }, } + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test the default initialization using environment variables""" + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pass"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_missing_env_vars(self, _mock_opensearch_client): + """Test that auth is None when environment variables are missing""" + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] is None + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test serialization with environment variables""" + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + document_store = OpenSearchDocumentStore(hosts="testhost") + serialized = document_store.to_dict() + + assert "http_auth" in serialized["init_parameters"] + auth = serialized["init_parameters"]["http_auth"] + assert isinstance(auth, list) + assert len(auth) == 2 + # Check that we have two Secret dictionaries with correct env vars + assert auth[0]["type"] == "env_var" + assert auth[0]["env_vars"] == ["OPENSEARCH_USERNAME"] + assert auth[1]["type"] == "env_var" + assert auth[1]["env_vars"] == ["OPENSEARCH_PASSWORD"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test deserialization with environment variables""" + # Set environment variables so the secrets resolve properly + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + data = { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": [ + {"type": "env_var", "env_vars": ["OPENSEARCH_USERNAME"], "strict": False}, + {"type": "env_var", "env_vars": ["OPENSEARCH_PASSWORD"], "strict": False}, + ], + }, + } + document_store = OpenSearchDocumentStore.from_dict(data) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pass"] + @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests):