#!/usr/bin/env python3

import argparse
import io
import sys
import xml.etree.ElementTree as ET

# Generator version: written into the header comment of every generated VAPI
# and reported by the --version command-line flag.
VERSION = "1.1.0"

# Wayland protocol argument types that have a fixed Vala spelling.
# map_vala_type() falls back to snake_to_pascal(type_name) for anything not
# listed here (e.g. interface names).
BASE_TYPE_MAP = {
    # NOTE(review): presumably must match the name used for wl_fixed_t in the
    # Wl (wayland-client) VAPI this output is compiled against — confirm.
    "fixed": "Wl.fixed_t",
    "string": "string",
    "array": "Wl.Array",
    "fd": "int32",
    "int": "int32",
    "uint": "uint32",
}


def snake_to_pascal(name: str) -> str:
    """Convert ``snake_case`` to ``PascalCase`` (e.g. ``xdg_toplevel`` -> ``XdgToplevel``).

    Each underscore-separated word is capitalized (first letter upper, rest
    lower) and the words are concatenated; empty words vanish.
    """
    return "".join(map(str.capitalize, name.split("_")))


def snake_to_camel(name: str) -> str:
    """Convert ``snake_case`` to ``camelCase`` (e.g. ``set_title`` -> ``setTitle``).

    The first word is kept exactly as written; every following word is
    capitalized and appended.
    """
    head, *tail = name.split("_")
    return head + "".join(map(str.capitalize, tail))


def snake_to_screaming_snake(name: str) -> str:
    """Convert ``snake_case`` to ``SCREAMING_SNAKE_CASE`` (``resize_edge`` -> ``RESIZE_EDGE``).

    The underscores are already in place, so only the letters need upper-casing.
    """
    return str.upper(name)


def map_vala_type(type_name, interface, interface_attr=None, enum_attr=None):
    """Map a Wayland protocol argument type to the Vala type used in the VAPI.

    Args:
        type_name: the XML ``type`` attribute (``int``, ``uint``, ``object``,
            ``new_id``, ...) or, on recursive calls, an interface name.
        interface: snake_case name of the interface owning the argument; used
            to qualify an ``enum`` attribute that has no ``interface.`` prefix.
        interface_attr: the XML ``interface`` attribute, if any.
        enum_attr: the XML ``enum`` attribute, if any.

    Returns:
        The Vala type name as a string.
    """
    if type_name in ("object", "new_id") and interface_attr:
        # Core interfaces (wl_*) live in the Wl namespace; all others map to a
        # PascalCase class. object and new_id must agree: previously a new_id
        # of wl_buffer became "WlBuffer" while an object of it was "Wl.Buffer".
        if interface_attr.startswith("wl_"):
            return f"Wl.{snake_to_pascal(interface_attr[3:])}"
        return map_vala_type(interface_attr, interface)

    if type_name == "new_id":
        # A new_id without an interface attribute has no static type here.
        return "void*"

    if enum_attr and type_name in ("int", "uint"):
        if "." in enum_attr:
            enum_interface, enum_name = enum_attr.split(".")
            return f"{map_vala_type(enum_interface, enum_interface)}{snake_to_pascal(enum_name)}"
        return f"{map_vala_type(interface, interface)}{snake_to_pascal(enum_attr)}"

    # NOTE(review): an "object" arg without an interface attribute falls through
    # to here and maps to "Object" — confirm that is the intended Vala type.
    return BASE_TYPE_MAP.get(type_name, snake_to_pascal(type_name))


def generate_docs(node, f_vapi, indent=False):
    """Write a ``/** ... */`` doc comment for *node* to *f_vapi*, if it has docs.

    The text comes from the node's ``<description>`` child when that has
    non-blank text (each line stripped). Otherwise it comes from the node's
    ``summary`` attribute, with runs of whitespace collapsed to single spaces.
    Nothing is written when neither is present.

    Args:
        node: XML element (interface, request, event, enum or entry).
        f_vapi: writable text stream that receives the comment.
        indent: when true, every comment line is prefixed with two spaces.
    """
    description_tag = node.find("description")
    description_text = description_tag.text if description_tag is not None else None
    summary = node.get("summary")

    if description_text and description_text.strip():
        lines = [line.strip() for line in description_text.strip().splitlines()]
    elif summary and summary.strip():
        lines = [" ".join(summary.split())]
    else:
        return

    indent_str = "  " if indent else ""
    out = [f"{indent_str}/**"]
    for line in lines:
        # A literal "*/" in protocol text would terminate the comment early
        # and break the generated VAPI.
        line = line.replace("*/", "* /")
        # Avoid trailing whitespace on blank paragraph separators.
        out.append(f"{indent_str} * {line}" if line else f"{indent_str} *")
    out.append(f"{indent_str} */")
    f_vapi.write("\n".join(out) + "\n")


def generate_version(node, f_vapi, indent=False):
    """Write a Vala ``[Version (...)]`` attribute derived from *node*'s XML attributes.

    ``since`` is copied as a string argument. ``deprecated-since`` becomes
    ``deprecated=true, deprecated_since="..."``. Nothing is written when
    neither XML attribute is present.

    Args:
        node: XML element carrying optional ``since`` / ``deprecated-since``.
        f_vapi: writable text stream that receives the attribute line.
        indent: when true, the line is prefixed with two spaces.
    """
    since = node.get("since")
    deprecated_since = node.get("deprecated-since")

    args = []
    if since:
        args.append(f'since="{since}"')
    if deprecated_since:
        # `deprecated` is a bool argument of Vala's Version attribute; the
        # quoted form "true" is a string literal, not a boolean.
        args.append("deprecated=true")
        args.append(f'deprecated_since="{deprecated_since}"')
    if not args:
        return

    indent_str = "  " if indent else ""
    f_vapi.write(f"{indent_str}[Version ({', '.join(args)})]\n")


# Vala reserved words, plus a few contextual ones such as get/set/value.
# A protocol argument whose name collides with one of these must be escaped
# with Vala's "@" identifier prefix, otherwise the generated VAPI fails to parse.
_VALA_RESERVED_WORDS = frozenset(
    {
        "abstract", "as", "async", "base", "break", "case", "catch", "class",
        "const", "construct", "continue", "default", "delegate", "delete",
        "do", "dynamic", "else", "ensures", "enum", "errordomain", "extern",
        "false", "finally", "for", "foreach", "get", "if", "in", "inline",
        "interface", "internal", "is", "lock", "namespace", "new", "null",
        "out", "override", "owned", "params", "private", "protected",
        "public", "ref", "requires", "return", "set", "signal", "sizeof",
        "static", "struct", "switch", "this", "throw", "throws", "true",
        "try", "typeof", "unlock", "unowned", "value", "var", "virtual",
        "void", "volatile", "weak", "while", "yield",
    }
)


def _vala_param_name(name):
    """Return *name*, escaped with "@" when it is a Vala reserved word."""
    return f"@{name}" if name in _VALA_RESERVED_WORDS else name


def generate_parameters(args, interface_name, return_new_id=True):
    """Build the Vala return type and parameter list for a request or event.

    Args:
        args: iterable of ``<arg>`` XML elements, in declaration order.
        interface_name: snake_case name of the owning interface; passed to
            map_vala_type to qualify enum types given without an interface
            prefix.
        return_new_id: when true, a ``new_id`` argument becomes the return
            type instead of a parameter (requests). When false it stays in
            the parameter list (events).

    Returns:
        ``(return_type, params)``. *return_type* is ``"void"`` unless a
        ``new_id`` arg was promoted; if several are, the last one wins.
        *params* is the comma-separated Vala parameter list.
    """
    params = []
    return_type = "void"
    for arg in args:
        arg_type = arg.get("type")
        vala_type = map_vala_type(
            arg_type,
            interface_name,
            interface_attr=arg.get("interface"),
            enum_attr=arg.get("enum"),
        )
        if arg.get("allow-null") == "true":
            vala_type = f"{vala_type}?"
        if arg_type == "new_id" and return_new_id:
            return_type = vala_type
        else:
            params.append(f"{vala_type} {_vala_param_name(arg.get('name'))}")

    return return_type, ", ".join(params)


def generate_requests(f_vapi, interface_name_snake, interface_name_vala, requests):
    """Emit one extern method declaration per protocol request.

    Each method is preceded by its doc comment and ``[Version]`` attribute.
    Requests with ``type="destructor"``, ``destroyer="true"`` or the name
    ``destroy`` are also tagged ``[DestroysInstance]``.
    *interface_name_vala* is not used here.
    """
    for req in requests:
        method_name = req.get("name")
        generate_docs(req, f_vapi, True)
        generate_version(req, f_vapi, True)

        ret_type, param_list = generate_parameters(
            req.findall("arg"), interface_name_snake
        )

        destroys = (
            req.get("type") == "destructor"
            or req.get("destroyer") == "true"
            or method_name == "destroy"
        )
        if destroys:
            f_vapi.write("  [DestroysInstance]\n")
        f_vapi.write(f"  public {ret_type} {method_name}({param_list});\n")


def generate_events(f_vapi, interface_name_snake, interface_name_vala, events):
    """Emit the listener struct and one callback delegate per event.

    The struct binds ``struct <interface>_listener`` and has one field per
    event, in declaration order. The field's type is the matching delegate
    declared afterwards. Each delegate takes ``void *data`` and the proxy,
    followed by the event's own arguments; new_id args stay as parameters.
    """
    listener_type = f"{interface_name_vala}Listener"

    f_vapi.write(
        f'[CCode (cname="struct {interface_name_snake}_listener", has_type_id=false)]\n'
    )
    f_vapi.write(f"public struct {listener_type} {{\n")
    for ev in events:
        field = ev.get("name")
        f_vapi.write(f"  public {listener_type}{snake_to_pascal(field)} {field};\n")
    f_vapi.write("}\n\n")

    for ev in events:
        delegate_suffix = snake_to_pascal(ev.get("name"))
        generate_docs(ev, f_vapi)
        generate_version(ev, f_vapi)
        _unused_return, param_list = generate_parameters(
            ev.findall("arg"), interface_name_snake, False
        )
        trailing = "" if not param_list else ", " + param_list
        f_vapi.write("[CCode (has_target=false, has_typedef=false)]\n")
        f_vapi.write(
            f"public delegate void {listener_type}{delegate_suffix}"
            f"(void *data, {interface_name_vala} {interface_name_snake}{trailing});\n\n"
        )


def generate_enum(
    f_vapi, enum, interface_name_snake, interface_name_vala, cheader_filename
):
    """Emit a Vala enum binding the C ``enum <interface>_<enum>``.

    Value names are the upper-cased entry names. Vala resolves their C names
    through a ``cprefix`` of ``<INTERFACE>_<ENUM>_``. Bitfield enums are
    marked ``[Flags]``.

    Args:
        f_vapi: writable text stream receiving the declaration.
        enum: the ``<enum>`` XML element.
        interface_name_snake: snake_case name of the owning interface.
        interface_name_vala: Vala name of the owning interface (type prefix).
        cheader_filename: C header that declares the enum.
    """
    enum_name_snake = enum.get("name")
    enum_name_vala = snake_to_pascal(enum_name_snake)
    generate_docs(enum, f_vapi, False)
    generate_version(enum, f_vapi, False)

    cprefix = (
        f"{snake_to_screaming_snake(interface_name_snake)}_"
        f"{snake_to_screaming_snake(enum_name_snake)}_"
    )
    f_vapi.write(
        f'[CCode (cprefix="{cprefix}", '
        f'cname="enum {interface_name_snake}_{enum_name_snake}", cheader_filename="{cheader_filename}")]'
        "\n"
    )
    if enum.get("bitfield") == "true":
        f_vapi.write("[Flags]\n")

    f_vapi.write(f"public enum {interface_name_vala}{enum_name_vala} {{\n")
    for entry in enum.findall("entry"):
        enum_value_vala = snake_to_screaming_snake(entry.get("name"))
        generate_docs(entry, f_vapi, True)
        generate_version(entry, f_vapi, True)
        if enum_value_vala[:1].isdigit():
            # Vala identifiers cannot start with a digit (e.g. an entry named
            # "90"). Prefix the Vala name and pin the original C name
            # explicitly, because cprefix + "_90" would be wrong.
            f_vapi.write(f'  [CCode (cname="{cprefix}{enum_value_vala}")]\n')
            enum_value_vala = f"_{enum_value_vala}"
        f_vapi.write(f"  {enum_value_vala},\n")
    f_vapi.write("}\n\n")
    f_vapi.flush()

def _find_free_function(interface, interface_name_snake):
    """Return the C function Vala should call to free a proxy of *interface*.

    The first request with ``type="destructor"``, ``destroyer="true"`` or the
    name ``destroy`` is used. Otherwise this falls back to
    ``<interface>_destroy``.
    """
    for request in interface.findall("request"):
        if (
            request.get("type") == "destructor"
            or request.get("destroyer") == "true"
            or request.get("name") == "destroy"
        ):
            return f"{interface_name_snake}_{request.get('name')}"
    return f"{interface_name_snake}_destroy"


def _generate_interface(f_vapi, interface, cheader_filename):
    """Write the compact class, listener types and enums for one ``<interface>``."""
    interface_name_snake = interface.get("name")
    interface_name_vala = map_vala_type(interface_name_snake, interface_name_snake)
    free_function = _find_free_function(interface, interface_name_snake)

    generate_docs(interface, f_vapi)
    f_vapi.write(
        f'[CCode (cheader_filename="{cheader_filename}", '
        f'cname="struct {interface_name_snake}", '
        f'cprefix="{interface_name_snake}_", free_function="{free_function}")]'
        f"\n[Compact]\npublic class {interface_name_vala} : Wl.Proxy {{\n"
    )

    f_vapi.write(f'  [CCode(cname="{interface_name_snake}_interface")]\n')
    f_vapi.write("  public static Wl.Interface iface;\n\n")

    f_vapi.write("  public void set_user_data(void* user_data);\n")
    f_vapi.write("  public void* get_user_data();\n")
    f_vapi.write("  public uint32 get_version();\n\n")

    requests = interface.findall("request")
    if requests:
        generate_requests(f_vapi, interface_name_snake, interface_name_vala, requests)

    events = interface.findall("event")
    if events:
        f_vapi.write(
            f"  public int add_listener({interface_name_vala}Listener listener, void* data);\n"
        )
    f_vapi.write("}\n\n")

    # Listener struct/delegates and enums are top-level, after the class body.
    if events:
        generate_events(f_vapi, interface_name_snake, interface_name_vala, events)

    for enum in interface.findall("enum"):
        generate_enum(
            f_vapi,
            enum,
            interface_name_snake,
            interface_name_vala,
            cheader_filename,
        )


def generate_vapi_from_xml(protocol_file, output_vapi_file, cheader_filename):
    """Generate a VAPI file for every interface in a Wayland protocol XML file.

    Errors are reported on stderr rather than raised. The output file is
    written only after generation has succeeded, so a failure never leaves
    a truncated VAPI behind.

    Args:
        protocol_file: path to the protocol XML.
        output_vapi_file: path of the VAPI file to (over)write, UTF-8 encoded.
        cheader_filename: C header name to reference in ``cheader_filename``.

    Returns:
        0 on success, 1 on failure.
    """
    try:
        root = ET.parse(protocol_file).getroot()
    except FileNotFoundError:
        print(f"Error: Protocol file not found: {protocol_file}", file=sys.stderr)
        return 1
    except ET.ParseError as e:
        print(f"Error parsing XML: {e}", file=sys.stderr)
        return 1
    except Exception as e:  # e.g. permission errors reading the input
        print(f"An error occurred: {e}", file=sys.stderr)
        return 1

    # Render into memory first so a mid-generation failure cannot leave a
    # half-written output file on disk.
    rendered = io.StringIO()
    try:
        rendered.write(f"// Generated VAPI file using wl-vapi-gen {VERSION}\n\n")
        for interface in root.findall("interface"):
            _generate_interface(rendered, interface, cheader_filename)
        with open(output_vapi_file, "w", encoding="utf-8") as f_vapi:
            f_vapi.write(rendered.getvalue())
    except Exception as e:  # top-level boundary: report, don't crash
        print(f"An error occurred: {e}", file=sys.stderr)
        return 1
    return 0

def main():
    """Command-line entry point: parse arguments and generate the VAPI file.

    Returns:
        Process exit status: 0 on success, non-zero when the generator
        reported a failure.
    """
    parser = argparse.ArgumentParser(
        description="Generate VAPI files from Wayland protocol XML."
    )
    parser.add_argument('--version', action='version', version=f'%(prog)s {VERSION}')
    parser.add_argument(
        "--protocol", required=True, help="Path to the Wayland protocol XML file."
    )
    parser.add_argument("--vapi", required=True, help="Path to the output VAPI file.")
    parser.add_argument(
        "--cheader",
        required=True,
        help="The c header file name generated by wayland-scanner.",
    )
    args = parser.parse_args()

    status = generate_vapi_from_xml(args.protocol, args.vapi, args.cheader)
    # A None status (generator reports no status) counts as success.
    return status or 0


if __name__ == "__main__":
    # Propagate the status so callers such as build systems see failures.
    sys.exit(main())
