"""Bitmask memory decoder.

The input consists of ``mask = ...`` lines, each followed by ``mem[addr] = value``
writes. Part A masks each written value; part B masks each address and writes
the value to every address produced by expanding the floating ('X') bits.
"""
from itertools import product


def combinations(address):
    """Expand every floating 'X' in *address* into both '0' and '1'.

    Args:
        address: A string or list of single characters ('0', '1' or 'X').

    Returns:
        A list of character lists, one per assignment of the floating bits
        (``2 ** address.count('X')`` entries).
    """
    floating = address.count('X')
    result = []
    for bits in product('01', repeat=floating):
        # The first 'X' receives the last bit of the tuple (matches the
        # original pop()-from-the-end ordering); the set of results is the
        # same either way.
        fill = reversed(bits)
        result.append([next(fill) if ch == 'X' else ch for ch in address])
    return result


def apply_mask_v1(mask, num):
    """Apply a value mask: '0'/'1' overwrite the bit, 'X' keeps *num*'s bit.

    Args:
        mask: String of '0', '1' and 'X', most significant bit first. Its
            length is the word width (bits of *num* above it are cleared).
        num: Non-negative integer value to mask.

    Returns:
        The masked integer.
    """
    set_bits = int(mask.replace('X', '0'), 2)   # bits forced to 1
    keep_bits = int(mask.replace('X', '1'), 2)  # bits not forced to 0
    return (num & keep_bits) | set_bits


def apply_mask_v2(mask, num):
    """Apply an address mask and return every decoded address.

    '0' keeps *num*'s bit, '1' sets it, and 'X' floats (takes both values).

    Args:
        mask: String of '0', '1' and 'X', most significant bit first.
        num: Non-negative integer address; it is padded to ``len(mask)`` bits.

    Returns:
        A list of integer addresses, one per floating-bit combination.
    """
    bits = bin(num)[2:].zfill(len(mask))
    masked = [n if m == '0' else m for n, m in zip(bits, mask)]
    return [int(''.join(r), 2) for r in combinations(masked)]


def _parse_write(inst):
    """Split an ``"addr,value"`` token into two ints."""
    addr, num = map(int, inst.split(','))
    return addr, num


def solve_a(data):
    """Return the sum of memory after writing masked values (part A).

    Args:
        data: Output of :func:`parse` - a list of ``[mask, "addr,value", ...]``.
    """
    mem = {}
    for group in data:
        mask = group[0]
        for inst in group[1:]:
            addr, num = _parse_write(inst)
            mem[addr] = apply_mask_v1(mask, num)
    return sum(mem.values())


def solve_b(data):
    """Return the sum of memory after writing to decoded addresses (part B).

    Args:
        data: Output of :func:`parse` - a list of ``[mask, "addr,value", ...]``.
    """
    mem = {}
    for group in data:
        mask = group[0]
        for inst in group[1:]:
            addr, num = _parse_write(inst)
            for a in apply_mask_v2(mask, addr):
                mem[a] = num
    return sum(mem.values())


def parse(text):
    """Turn the raw program text into ``[[mask, "addr,value", ...], ...]``.

    Chunks that are empty after stripping (e.g. text before the first
    ``mask = ``) are skipped.
    """
    data = []
    for chunk in text.split('mask = '):
        chunk = chunk.strip()
        if not chunk:
            continue
        data.append(chunk.replace('mem[', '').replace('] = ', ',').split())
    return data


def main(path='input'):
    """Read the program from *path* and print the part A and part B answers."""
    with open(path, encoding='utf-8') as f:
        data = parse(f.read())
    print(solve_a(data))
    print(solve_b(data))


if __name__ == '__main__':
    main()