diff --git a/src/dotenv/__init__.py b/src/dotenv/__init__.py index 7f4c631b..ce68b28b 100644 --- a/src/dotenv/__init__.py +++ b/src/dotenv/__init__.py @@ -1,7 +1,7 @@ from typing import Any, Optional -from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_key, - unset_key) +from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_header, + set_key, unset_key) def load_ipython_extension(ipython: Any) -> None: @@ -42,6 +42,7 @@ def get_cli_string( __all__ = ['get_cli_string', 'load_dotenv', 'dotenv_values', + 'set_header', 'get_key', 'set_key', 'unset_key', diff --git a/src/dotenv/main.py b/src/dotenv/main.py index 1848d602..0fbe48d2 100644 --- a/src/dotenv/main.py +++ b/src/dotenv/main.py @@ -5,6 +5,7 @@ import shutil import sys import tempfile +import textwrap from collections import OrderedDict from contextlib import contextmanager from typing import IO, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union @@ -396,3 +397,34 @@ def dotenv_values( override=True, encoding=encoding, ).dict() + + +def set_header( + dotenv_path: StrPath, + header: str, + encoding: Optional[str] = "utf-8", +) -> Tuple[bool, Optional[str]]: + """ + Adds or Updates a header in the .env file + + Parameters: + dotenv_path: Absolute or relative path to .env file. + header: The desired header block + encoding: Encoding to be used to read the file. + Returns: + Bool: True if at least one environment variable is set else False + Str: The header that was written + """ + with rewrite(dotenv_path, encoding=encoding) as (source, dest): + if not header or not header.strip(): + logger.info("Ignoring empty header.") + return False, header + + lines = textwrap.wrap(header.replace("\n", " "), width=60) + header = "".join(f"# {line}\n" for line in lines) + dest.write(header) + + text = "".join(atom for atom in source.readlines() if not atom.startswith("#")) + dest.write(f"{text}\n") + + return True, header diff --git a/tests/test_main.py b/tests/test_main.py index 2d63eec1..1e951a06 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -396,3 +396,53 @@ def test_dotenv_values_file_stream(dotenv_path): result = dotenv.dotenv_values(stream=f) assert result == {"a": "b"} + + +@pytest.mark.parametrize( + "header", + [ + "", + " ", + None, + ], +) +def test_set_header_empty(dotenv_path, header): + logger = logging.getLogger("dotenv.main") + + with mock.patch.object(logger, "info") as mock_info: + result, *_ = dotenv.set_header(dotenv_path, header) + assert not result + + mock_info.assert_called() + + +@pytest.mark.parametrize( + "new_header, expected, expected_header, old_header, content", + [ + ("single-line input", True, "# single-line input\n", "", "a=b\nc=d"), + ("multi-line\ninput", True, "# multi-line input\n", "", "a=b\nc=d"), + ("new header", True, "# new header\n", "# old header", "a=b\nc=d"), + ( + " ".join("x" * 57 for _ in range(2)), + True, + "".join(f"# {'x' * 57}\n" for _ in range(2)), + "", + "a=b", + ), + ], +) +def test_set_header( + dotenv_path, new_header, expected, expected_header, old_header, content +): + logger = logging.getLogger("dotenv.main") + dotenv_path.write_text(f"{old_header}\n{content}") + + with mock.patch.object(logger, "warning") as mock_warning: + result, written = dotenv.set_header(dotenv_path, new_header) + assert result == expected + assert written == expected_header + + text = dotenv_path.read_text() + assert content in text + assert expected_header in text + mock_warning.assert_not_called()