diff --git a/pyproject.toml b/pyproject.toml index 339f39d..9882505 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ssb-konjunk" -version = "0.0.3" +version = "0.0.4" description = "SSB Konjunk" authors = ["Edvard Garmannslund "] license = "MIT" diff --git a/src/ssb_konjunk/prompts.py b/src/ssb_konjunk/prompts.py index 23375d3..04f0d05 100644 --- a/src/ssb_konjunk/prompts.py +++ b/src/ssb_konjunk/prompts.py @@ -6,6 +6,7 @@ import re from calendar import monthrange from datetime import datetime +from typing import Any def _input_valid_int() -> int: @@ -203,3 +204,42 @@ def delta_month(month: int, periods: int) -> int: elif new_month < 1: new_month = new_month + 12 return new_month + + +def iterate_years_months( + start_year: int, end_year: int, start_month: int, end_month: int +) -> Any: + """Function to iterate over years and month. + + Allows you to select start year, start month, end year and end month + + Args: + start_year: Int for start year. + start_month: Int for start month. + end_year: Int for end year. + end_month: Int for end month. + + Yields: + Any: A tuple containing the year and month for each combination. + + Raises: + ValueError: If start year is bigger than end year. + ValueError: If month is invalid number. + ValueError: If end month is bigger than start and only iterating on one year. + """ + if start_year > end_year: + raise ValueError("Start year must be less than or equal to end year") + if start_month < 1 or start_month > 12 or end_month < 1 or end_month > 12: + raise ValueError("Month must be between 1 and 12") + if start_year == end_year and start_month > end_month: + raise ValueError( + "If iterating in same year start month must be less than end month." + ) + + for year in range(start_year, end_year + 1): + for month in range(1, 13): + if (year == start_year and month < start_month) or ( + year == end_year and month > end_month + ): + continue + yield year, month diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 8ad1e31..61fbd05 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -5,6 +5,7 @@ from ssb_konjunk.prompts import days_in_month from ssb_konjunk.prompts import delta_month from ssb_konjunk.prompts import extract_start_end_dates +from ssb_konjunk.prompts import iterate_years_months """Test of function days in month""" @@ -90,3 +91,58 @@ def test_delta_month() -> None: delta_month(12, -12) with pytest.raises(ValueError): delta_month(12, 0) + + +"""Test of function iterate_years_months""" + + +def test_iterate_years_months_full_range() -> None: + # Test when providing a full range of years and months + expected_output = [ + (2024, 1), + (2024, 2), + (2024, 3), + (2024, 4), + (2024, 5), + (2024, 6), + (2024, 7), + (2024, 8), + (2024, 9), + (2024, 10), + (2024, 11), + (2024, 12), + (2025, 1), + ] + assert list(iterate_years_months(2024, 2025, 1, 1)) == expected_output + + +def test_iterate_years_months_partial_range() -> None: + # Test when providing a partial range of years and months + expected_output = [(2023, 11), (2023, 12), (2024, 1), (2024, 2)] + assert list(iterate_years_months(2023, 2024, 11, 2)) == expected_output + + +def test_iterate_years_months_one_period() -> None: + # Test when providing a partial range of years and months + expected_output = [(2024, 2)] + assert list(iterate_years_months(2024, 2024, 2, 2)) == expected_output + + +def test_iterate_years_months_invalid_range() -> None: + # Test when providing an invalid range where start year > end year + with pytest.raises(ValueError): + list(iterate_years_months(2024, 2022, 1, 12)) + + # Test when providing an invalid range where start month > end month + with pytest.raises(ValueError): + list(iterate_years_months(2024, 2024, 6, 1)) + + +def test_iterate_years_months_invalid_month() -> None: + # Test when providing an invalid month (greater than 12) + with pytest.raises(ValueError): + list(iterate_years_months(2022, 2024, 1, 13)) + + # Test when providing an invalid month (less than 1) + with pytest.raises(ValueError): + list(iterate_years_months(2022, 2024, 0, 12))