from collections import defaultdict
import itertools
import random
from typing import Any, List, Tuple
import networkx as nx
import time

def hopcroft_karp(arks: List[Tuple[Any, Any]]) -> List[Tuple[Any, Any]]:
    """
    Compute a maximum matching of a bipartite graph (Hopcroft-Karp).

    Each phase runs a BFS from every unmatched vertex of the first component
    V1 to assign layers to the V1 vertices reachable by alternating paths,
    then a DFS that augments along paths whose V1 vertices lie on strictly
    increasing layers. Phases repeat until the BFS reaches no unmatched
    vertex of V2, i.e. no augmenting path exists, at which point the
    matching is maximum.

    :param arks: arcs ``(v, w)`` with ``v`` in V1 and ``w`` in V2. The two
        components are kept separate, so the same label may occur on both
        sides. Vertices must be hashable.
    :return: the matched arcs as ``(v, w)`` pairs; ``[]`` for empty input.
    """
    from collections import deque

    # Adjacency V1 -> V2 (a plain dict, so lookups never insert new keys).
    neighbours = {}
    for v, w in arks:
        neighbours.setdefault(v, []).append(w)

    match_v1 = {}  # v in V1 -> its partner in V2
    match_v2 = {}  # w in V2 -> its partner in V1

    def build_layers():
        """Return the BFS layer of every V1 vertex reachable by alternating
        paths from an unmatched V1 vertex, or None if no unmatched V2 vertex
        is reachable (then no augmenting path exists)."""
        layer = {v: 0 for v in neighbours if v not in match_v1}
        queue = deque(layer)
        reaches_free = False
        while queue:
            v = queue.popleft()
            for w in neighbours[v]:
                if w not in match_v2:
                    reaches_free = True
                    continue
                # Follow the matched edge from w back into V1.
                u = match_v2[w]
                if u not in layer:
                    layer[u] = layer[v] + 1
                    queue.append(u)
        return layer if reaches_free else None

    def augment(root, layer):
        """Search for an augmenting path from the unmatched V1 vertex ``root``
        inside the layered graph and flip it into the matching. The search is
        iterative to avoid recursion limits on long paths. Return True on
        success."""
        stack = [(root, iter(neighbours[root]))]
        path = []  # arcs (v, w) taken so far; len(path) == len(stack) - 1
        while stack:
            v, candidates = stack[-1]
            for w in candidates:
                if w not in match_v2:
                    # Reached an unmatched V2 vertex: swap matched/unmatched arcs.
                    path.append((v, w))
                    for a, b in path:
                        match_v1[a] = b
                        match_v2[b] = a
                    return True
                u = match_v2[w]
                if layer.get(u) == layer[v] + 1:
                    path.append((v, w))
                    stack.append((u, iter(neighbours[u])))
                    break
            else:
                # v cannot reach an unmatched V2 vertex this phase: prune it.
                layer[v] = None
                stack.pop()
                if path:
                    path.pop()
        return False

    while True:
        layer = build_layers()
        if layer is None:
            break
        for root in neighbours:
            if root not in match_v1:
                augment(root, layer)
        print(f"best matching yet: {set(match_v1.items())}")

    return list(match_v1.items())

def bip_matching_sat(arks: List[Tuple[Any, Any]]) -> List[Tuple[Any, Any]]:
    """
    Solve maximum bipartite matching with a cardinality-capable SAT solver.

    One Boolean variable is created per distinct arc ``(v, w)``. At-most-one
    constraints over the arcs incident to each vertex make every model a
    matching. The solver is then asked for at least ``i`` selected arcs for
    ``i = 1, 2, ...`` until the formula becomes unsatisfiable (or every arc is
    selected). The last model found is a maximum matching.

    Requires the ``python-sat`` package, which is imported lazily.

    :param arks: arcs ``(v, w)`` with ``v`` in the first and ``w`` in the
        second component. The components are kept separate, so the same label
        may occur on both sides; duplicate arcs are treated as one.
    :return: the arcs of a maximum matching in first-occurrence order;
        ``[]`` for empty input.
    """
    if not arks:
        return []

    import pysat.solvers as pysol  # install via python-sat

    # Map each distinct arc to a 1-based SAT variable. setdefault keeps a
    # duplicate arc from reassigning (and colliding on) variable numbers.
    arc_var = {}
    for v, w in arks:
        arc_var.setdefault((v, w), len(arc_var) + 1)

    # Separate incidence lists per side, so a label shared by both components
    # does not merge two different vertices into one constraint.
    incident_v1 = defaultdict(list)
    incident_v2 = defaultdict(list)
    for (v, w), var in arc_var.items():
        incident_v1[v].append(var)
        incident_v2[w].append(var)

    all_vars = list(arc_var.values())
    best_model = None

    # The context manager frees the native solver instance on exit.
    with pysol.Solver(name="Minicard") as sat:
        for lits in itertools.chain(incident_v1.values(), incident_v2.values()):
            sat.add_atmost(lits, 1)

        # A matching has at most len(all_vars) arcs, so the bound below
        # never becomes negative.
        for i in range(1, len(all_vars) + 1):
            # "at least i of all_vars" == "at most len - i of their negations";
            # for i == 1 a plain clause suffices. Constraints accumulate, which
            # is sound because "at least i" implies "at least i - 1".
            if i == 1:
                sat.add_clause(all_vars)
            else:
                sat.add_atmost([-x for x in all_vars], len(all_vars) - i)

            if not sat.solve():
                break

            best_model = set(sat.get_model())
            print(f"Solution for {i} matches found: {best_model}")

    if best_model is None:
        return []
    # Positive literals in the model are the selected arcs.
    return [arc for arc, var in arc_var.items() if var in best_model]