obsolete.computer

misc-scripts/spfflatten.py

File Type: text/x-script.python


import argparse
import ipaddress
from dns import resolver, exception as dns_exception

# Module-wide verbosity flag; main() overwrites it from -v/--verbose.
# Defined here so helpers can be called without going through main().
VERBOSE = False


def print_verbose(msg):
    """Print *msg* prefixed with ``[VERBOSE]`` when verbose mode is on.

    Args:
        msg: text to print; nothing is printed when ``VERBOSE`` is falsy.
    """
    if VERBOSE:
        print(f"[VERBOSE] {msg}")

def get_spf(domain):
    """Return the SPF record published in *domain*'s TXT records.

    Each TXT record's character-strings are concatenated before checking.
    A record counts as SPF only if it is exactly ``v=spf1`` or starts with
    ``v=spf1`` followed by a space (case-insensitive), so a record such as
    ``v=spf10`` is not mistaken for SPF. The first matching record is
    returned as text.

    Raises:
        ValueError: if the TXT lookup fails or no SPF record is found.
    """
    print_verbose(f"Looking up TXT for {domain}")
    try:
        answers = resolver.resolve(domain, 'TXT')
    except dns_exception.DNSException as e:
        print_verbose(f"TXT lookup failed for {domain}: {e}")
        raise ValueError(f"DNS lookup failed for {domain}: {e}") from e

    for rdata in answers:
        # rdata.strings holds bytes. Decode leniently so that non-ASCII bytes
        # in some unrelated TXT record cannot abort the search for SPF.
        txt = b''.join(rdata.strings).decode('ascii', errors='replace').strip().replace('"', '')
        lowered = txt.lower()
        if lowered == 'v=spf1' or lowered.startswith('v=spf1 '):
            print_verbose(f"Found SPF for {domain}: {txt}")
            return txt
    raise ValueError(f"No SPF record found for {domain}")

def _cidr_length(text, maximum, suffix):
    """Return *text* as an int prefix length in [0, maximum]; else ValueError."""
    if not (text.isdecimal() and int(text) <= maximum):
        raise ValueError(f"Invalid CIDR length in SPF mechanism: {suffix!r}")
    return int(text)


def _parse_dual_cidr(suffix):
    """Parse an SPF dual-cidr-length suffix into ``(ip4_len, ip6_len)``.

    *suffix* is what follows the (optional) domain of an ``a``/``mx``
    mechanism: ``""``, ``"/24"``, ``"//64"`` or ``"/24//64"``. A missing
    length is returned as None, meaning "no mask" (single host address).

    Raises:
        ValueError: on any other syntax or an out-of-range length.
    """
    if not suffix:
        return None, None
    # '//' introduces the IPv6 length; anything before it is the IPv4 part.
    ip4_text, sep, ip6_text = suffix.partition('//')
    cidr4 = cidr6 = None
    if ip4_text:
        if not ip4_text.startswith('/'):
            raise ValueError(f"Invalid CIDR length in SPF mechanism: {suffix!r}")
        cidr4 = _cidr_length(ip4_text[1:], 32, suffix)
    if sep:
        cidr6 = _cidr_length(ip6_text, 128, suffix)
    return cidr4, cidr6


def _split_a_mx(mechanism, domain):
    """Split an ``a``/``mx`` mechanism (qualifier removed) into its parts.

    Returns ``(type_, target, cidr4, cidr6)``. *type_* is ``'a'`` or ``'mx'``,
    and *target* falls back to *domain* when the mechanism names no domain.
    """
    type_ = 'mx' if mechanism.startswith('mx') else 'a'
    rest = mechanism[len(type_):]  # '', ':host', ':host/24//64', '/24', '//64'
    if rest.startswith(':'):
        target, slash, lengths = rest[1:].partition('/')
        suffix = slash + lengths
    else:
        target, suffix = '', rest
    cidr4, cidr6 = _parse_dual_cidr(suffix)
    return type_, target or domain, cidr4, cidr6


def _resolve_a_mx(type_, target, cidr4, cidr6, ignore_ipv6):
    """Resolve an ``a``/``mx`` mechanism to address strings.

    Each entry is a bare IP or, when a prefix length applies, ``net/len``.
    DNS failures are reported via print_verbose and skipped (best-effort).
    """
    if type_ == 'mx':
        print_verbose(f"Looking up MX for {target}")
        try:
            mx_answers = resolver.resolve(target, 'MX')
            # Drop empty names, e.g. a null MX ('.'), which has no host.
            hosts = [h for h in (str(r.exchange).rstrip('.') for r in mx_answers) if h]
            print_verbose(f"MX records for {target}: {', '.join(hosts)}")
        except dns_exception.DNSException as e:
            print_verbose(f"MX lookup failed for {target}: {e}")
            hosts = []
    else:
        hosts = [target]
        print_verbose(f"Using A/AAAA for domain {target}")

    # Each address family uses its own prefix length (IPv6 never inherits
    # the IPv4 one).
    lookups = [('A', cidr4, ipaddress.IPv4Network)]
    if not ignore_ipv6:
        lookups.append(('AAAA', cidr6, ipaddress.IPv6Network))

    results = []
    for host in hosts:
        for rdtype, mask, network_cls in lookups:
            print_verbose(f"Looking up {rdtype} for {host}")
            try:
                answers = resolver.resolve(host, rdtype)
            except dns_exception.DNSException as e:
                print_verbose(f"{rdtype} lookup failed for {host}: {e}")
                continue
            for r in answers:
                ip = str(r)
                if mask is None:
                    print_verbose(f"Resolved {rdtype} → {ip}")
                    results.append(ip)
                else:
                    net = network_cls(f"{ip}/{mask}", strict=False)
                    print_verbose(f"Resolved {rdtype} → {ip} → CIDR {net}")
                    results.append(str(net))
    return results


def collect_ips(domain, seen=None, ignore_ipv6=False):
    """Recursively flatten *domain*'s SPF record into IP/CIDR strings.

    Only pass-qualified terms (``+`` or no qualifier) contribute addresses:
    ``ip4``, ``ip6``, ``a``, ``mx`` and ``include`` (recursively). ``ptr``,
    ``exists`` and other modifiers (e.g. ``exp=``) are skipped. A
    ``redirect=`` modifier is followed only when the record contains no
    ``all`` term of any qualifier. The redirect target's addresses are added
    to this record's own addresses.

    Args:
        domain: domain whose SPF record is flattened.
        seen: domains on the current include/redirect chain, used for loop
            detection. The caller's set is not modified.
        ignore_ipv6: drop ``ip6`` entries and skip AAAA lookups.

    Returns:
        A list of address strings, possibly containing duplicates.

    Raises:
        ValueError: on a failed/missing/invalid SPF lookup, a malformed CIDR
            length, or an include/redirect loop.
    """
    seen = set() if seen is None else seen
    if domain in seen:
        print_verbose(f"Cycle detected → stopping recursion at {domain}")
        raise ValueError(f"Cycle detected in SPF includes: {domain}")
    # Track only the current chain, not every domain ever visited. The same
    # domain reached through two different includes is therefore not
    # reported as a loop.
    chain = seen | {domain}

    spf = get_spf(domain)
    parts = spf.lower().split()
    if parts[0] != 'v=spf1':
        raise ValueError(f"Invalid SPF record for {domain}")

    ips = []
    redirect = None
    has_all = False
    for term in parts[1:]:
        qualifier, mechanism = '+', term
        if term[0] in '+-~?':
            qualifier, mechanism = term[0], term[1:]

        # Check 'all' before the qualifier filter: '-all', '~all' etc. also
        # mean that redirect= must be ignored.
        if mechanism == 'all':
            has_all = True
            print_verbose("Found 'all' mechanism")
            continue
        if mechanism.startswith('redirect='):
            redirect = mechanism[len('redirect='):]
            print_verbose(f"Found redirect={redirect}")
            continue
        if qualifier != '+':
            continue  # fail/softfail/neutral terms authorise no addresses

        if mechanism.startswith('ip4:'):
            ip_str = mechanism[4:]
            print_verbose(f"Found ip4: {ip_str}")
            ips.append(ip_str)
        elif mechanism.startswith('ip6:'):
            ip_str = mechanism[4:]
            print_verbose(f"Found ip6: {ip_str}")
            if not ignore_ipv6:
                ips.append(ip_str)
        elif mechanism.startswith('include:'):
            sub_domain = mechanism[8:]
            print_verbose(f"Recursing into include: {sub_domain}")
            ips.extend(collect_ips(sub_domain, chain, ignore_ipv6))
        elif mechanism in ('a', 'mx') or mechanism.startswith(('a:', 'a/', 'mx:', 'mx/')):
            type_, target, cidr4, cidr6 = _split_a_mx(mechanism, domain)
            ips.extend(_resolve_a_mx(type_, target, cidr4, cidr6, ignore_ipv6))
        # ptr, exists and unknown terms are intentionally skipped.

    if redirect and not has_all:
        print_verbose(f"Following redirect to {redirect}")
        ips.extend(collect_ips(redirect, chain, ignore_ipv6))

    return ips

def main():
    """Command-line entry point: flatten a domain's SPF record and print it.

    Output is either a comma-separated address list (``--output list``) or a
    single ``v=spf1 ... -all`` record (``--output spf``). IPv4 entries are
    printed before IPv6 entries, and each family is sorted numerically. On
    failure the error goes to stderr (with a traceback when verbose) and the
    process exits with status 1.
    """
    global VERBOSE
    import sys

    parser = argparse.ArgumentParser(description="Flatten SPF record to IP addresses.")
    parser.add_argument("domain", help="The domain name to process.")
    parser.add_argument("--output", choices=["list", "spf"], default="list",
                        help="Output format: 'list' (comma-separated) or 'spf' (flattened SPF entry).")
    parser.add_argument("--ignore-ipv6", action="store_true", help="Ignore IPv6 entries.")
    parser.add_argument("-v", "--verbose", action="store_true", help="Show detailed DNS lookup information.")
    args = parser.parse_args()

    VERBOSE = args.verbose

    try:
        ips = collect_ips(args.domain, ignore_ipv6=args.ignore_ipv6)
        # The family flag comes first in the key, so IPv4Network and
        # IPv6Network objects (which are not mutually orderable) are never
        # compared with each other.
        unique_ips = sorted(set(ips), key=lambda x: (':' in x, ipaddress.ip_network(x, strict=False)))
    except ValueError as e:
        # Report on stderr with a non-zero exit so that scripts consuming
        # stdout never mistake the error text for an address list.
        print(f"Error: {e}", file=sys.stderr)
        if VERBOSE:
            import traceback
            traceback.print_exc()
        sys.exit(1)

    if args.output == "list":
        print(",".join(unique_ips))
    else:  # "spf"
        terms = ["v=spf1"]
        terms.extend(("ip6:" if ':' in ip else "ip4:") + ip for ip in unique_ips)
        terms.append("-all")
        # Joining one list avoids a double space when there are no addresses.
        print(" ".join(terms))


if __name__ == "__main__":
    main()

Meta