-
Notifications
You must be signed in to change notification settings - Fork 0
/
AoC2023_12.py
127 lines (103 loc) Β· 3.29 KB
/
AoC2023_12.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#! /usr/bin/env python3
#
# Advent of Code 2023 Day 12
#
import sys
from functools import cache
from itertools import product
from typing import Callable
from aoc.common import InputData
from aoc.common import SolutionBase
from aoc.common import aoc_samples
# Type parameters for SolutionBase: the raw input is passed through
# unchanged, and both parts produce an integer answer.
Input = InputData
Output1 = Output2 = int

# Sample input used by the @aoc_samples checks: each line is a row of
# '.', '#' and '?' characters followed by comma-separated group sizes.
TEST = "\n".join(
    (
        "???.### 1,1,3",
        ".??..??...?##. 1,1,3",
        "?#?#?#?#?#?#?#? 1,3,1,6",
        "????.#...#... 4,1,1",
        "????.######..#####. 1,6,5",
        "?###???????? 3,2,1",
        "",  # keeps the trailing newline of the original literal
    )
)
def count(s: str, counts: tuple[int, ...]) -> int:
    """Return how many ways the '?' in ``s`` can be resolved to match ``counts``.

    ``counts`` lists, in order, the required lengths of the runs of '#'.
    The memo of the recursive helper is emptied before the search starts,
    so entries from previously counted rows are not retained.
    """
    reset_memo = _count.cache_clear
    reset_memo()
    # start at the beginning of the row, outside of any group
    starting_group_length = 0
    return _count(s, counts, starting_group_length)
@cache
def _count(s: str, counts: tuple[int, ...], idx: int) -> int:
# base case: end of s
if s == "":
if len(counts) == 0 and idx == 0:
# no more groups required and not in group -> OK
return 1
else:
# still group(s) required or still in group -> NOK
return 0
# otherwise:
ans = 0
# '?' can be '.' or '#'
nxts = {".", "#"} if s[0] == "?" else {s[0]}
for nxt in nxts:
if nxt == "#":
# if '#': move to next char in current group
ans += _count(s[1:], counts, idx + 1)
elif nxt == ".":
# if '.' (between groups)
if idx > 0:
# was in group before
if len(counts) > 0 and idx == counts[0]:
# finished group matches required -> find next required
ans += _count(s[1:], counts[1:], 0)
else:
# was not in group: keep looking for next group
ans += _count(s[1:], counts, 0)
else:
# should not happen
assert False
return ans
class Solution(SolutionBase[Input, Output1, Output2]):
    """Advent of Code 2023 Day 12 solver.

    Each input line holds a row of '.', '#' and '?' characters and the
    comma-separated sizes of its contiguous '#' groups. Each part sums the
    number of matching arrangements over all lines.
    """

    def parse_input(self, input_data: InputData) -> Input:
        """Return the input unchanged; lines are parsed per part."""
        return input_data

    @staticmethod
    def _parse(line: str, copies: int = 1) -> tuple[str, tuple[int, ...]]:
        """Split ``line`` into its row and group sizes, unfolded ``copies`` times.

        The row copies are joined with '?' and the group sizes are
        repeated. A trailing '.' is appended so the last group is closed.
        """
        s, w = line.split()
        counts = tuple(int(size) for size in w.split(","))
        return "?".join([s] * copies) + ".", counts * copies

    def brute_force_count(self, s: str, counts: tuple[int, ...]) -> int:
        """Count arrangements by trying every assignment of the '?' characters.

        This is exponential in the number of '?'. It is not called from the
        visible code; presumably it is a reference check for ``count`` on
        small rows.
        """
        ans = 0
        total = sum(counts)
        known = s.count("#")
        for p in product("#.", repeat=s.count("?")):
            # an assignment with the wrong number of '#' can never match
            if known + p.count("#") != total:
                continue
            fill = iter(p)
            test = "".join(next(fill) if ch == "?" else ch for ch in s)
            groups = tuple(len(group) for group in test.split(".") if group)
            if groups == counts:
                ans += 1
        return ans

    def solve(
        self, input: Input, f: Callable[[str], tuple[str, tuple[int, ...]]]
    ) -> int:
        """Sum ``count`` over all input lines, each parsed by ``f``."""
        # blank lines carry no row (and would break the parser's unpacking)
        return sum(count(*f(line)) for line in input if line.strip())

    def part_1(self, input: Input) -> Output1:
        """Total arrangements over the rows as given."""
        return self.solve(input, self._parse)

    def part_2(self, input: Input) -> Output2:
        """Total arrangements over the rows unfolded five times."""
        return self.solve(input, lambda line: self._parse(line, 5))

    @aoc_samples(
        (
            ("part_1", TEST, 21),
            ("part_2", TEST, 525152),
        )
    )
    def samples(self) -> None:
        pass
# Solver instance for year 2023, day 12.
solution = Solution(2023, 12)


def main() -> None:
    """Entry point: hand the command-line arguments to the solver."""
    cli_args = sys.argv
    solution.run(cli_args)


if __name__ == "__main__":
    main()