from pathlib import Path
import os
import sys
import pandas as pd
from datetime import datetime

configfile: "config.yaml"

# Load the sample sheet named by config["samples_csv"]; abort the workflow
# immediately when the file is absent. The frame is indexed by its ID column,
# and the column itself is kept (drop=False).
if not os.path.exists(config["samples_csv"]):
    sys.exit(f"Error: samples.csv file '{config['samples_csv']}' does not exist")

sample_data = pd.read_csv(config["samples_csv"]).set_index("ID", drop=False)

def get_taxid(wildcards):
    """Return the ``taxid`` entry of ``sample_data`` for the sample ``wildcards.ID``.

    The lookup is by index label, so ``wildcards.ID`` must match a label of
    the ``sample_data`` index exactly.
    """
    taxid_by_sample = sample_data["taxid"]
    return taxid_by_sample.loc[wildcards.ID]

# Validation functions
def validate_config():
    """Validate the loaded ``config`` mapping before the rules read from it.

    Checks, in order:

    1. ``target_config.target_type`` is present and one of the supported targets.
    2. The parameter section for that target (``gene_params`` for ``"gene"``,
       ``fetch_params`` otherwise) contains every required key. A missing or
       empty section is reported as "all keys missing" rather than surfacing
       as a bare ``KeyError``/``AttributeError``.
    3. For gene targets, the ``samples_csv`` file exists.
    4. The top-level ``run_name``, ``email`` and ``api_key`` keys exist.

    Raises:
        ValueError: a required key or section is missing, or ``target_type``
            is not a supported value.
        FileNotFoundError: ``target_type`` is ``"gene"`` and the
            ``samples_csv`` path does not exist.
    """
    target_config = config.get("target_config")
    if not isinstance(target_config, dict) or "target_type" not in target_config:
        raise ValueError("target_config.target_type must be specified in config.yaml")

    target_type = target_config["target_type"]
    valid_targets = {"gene", "chloroplast", "mitochondrion", "ribosomal", "ribosomal_complete"}

    if target_type not in valid_targets:
        raise ValueError(f"Invalid target_type: {target_type}. Must be one of: {valid_targets}")

    # Validate parameters based on target type
    if target_type == "gene":
        section_name = "gene_params"
        required_params = {"protein_size", "nucleotide_size", "sequence_types", "gene_names"}
    else:
        section_name = "fetch_params"
        required_params = {"database", "min_sequences", "max_sequences"}

    # An absent section, or one left empty in the YAML (parsed as None), is
    # treated as an empty mapping so the user gets the "Missing required ..."
    # message listing every key they still need to provide.
    section = target_config.get(section_name)
    if not isinstance(section, dict):
        section = {}
    missing_params = required_params - set(section)
    if missing_params:
        raise ValueError(f"Missing required {section_name}: {missing_params}")

    # Validate samples.csv exists when target_type is gene
    if target_type == "gene" and not os.path.exists(config["samples_csv"]):
        raise FileNotFoundError(f"samples_csv file not found: {config['samples_csv']}")

    # Validate run_name exists
    if "run_name" not in config:
        raise ValueError("run_name must be specified in config.yaml")

    # Both fetch rules read config["email"] and config["api_key"] while the
    # file is parsed, so a missing key would otherwise fail later with an
    # uninformative KeyError.
    for required_key in ("email", "api_key"):
        if required_key not in config:
            raise ValueError(f"{required_key} must be specified in config.yaml")

# Validate configuration once, while the file is being parsed, so that a bad
# config fails here instead of inside the rule definitions further down.
validate_config()

#def get_final_outputs(wildcards):
#    """Determine final outputs based on target type"""
#    outputs = []
#    target_type = config["target_config"]["target_type"]
#    run_name = config["run_name"]
#    
#    if target_type == "gene":
#        # Handle gene targets
#        for gene in config["target_config"]["gene_params"]["gene_names"]:
#            gene_dir = os.path.join("results", gene, run_name)
#            outputs.extend([
#                os.path.join(gene_dir, "sequence_references.csv"),
#                os.path.join(gene_dir, "gene_fetch.log")
#            ])
#    else:
#        # Handle organelle/ribosomal targets
#        for sample_id in sample_data.index:
#            target_dir = os.path.join("results", target_type, run_name, sample_id)
#            if config["target_config"]["fetch_params"]["getorganelle"]:
#                outputs.extend([
#                    os.path.join(target_dir, "seed.fasta"),
#                    os.path.join(target_dir, "gene.fasta")
#                ])
#            else:
#                outputs.extend([
#                    os.path.join(target_dir, "fasta"),
#                    os.path.join(target_dir, "genbank")
#                ])
#    return outputs

def get_final_outputs(wildcards):
    """Return the list of result paths the workflow must produce.

    Args:
        wildcards: supplied by Snakemake when this is used as an input
            function; not read here.

    Returns:
        list[str]: for ``target_type == "gene"``, the
        ``sequence_references.csv`` and ``gene_fetch.log`` paths of every
        configured gene under ``results/<gene>/<run_name>/``. For every other
        target type, the per-sample paths under
        ``results/<target_type>/<run_name>/<ID>/<taxid>/``:
        ``seed.fasta``, ``gene.fasta`` and the annotated-regions log when
        ``fetch_params.getorganelle`` is truthy, otherwise ``fasta`` and
        ``genbank``.
    """
    target_config = config["target_config"]
    target_type = target_config["target_type"]
    run_name = config["run_name"]

    if target_type == "gene":
        # Handle gene targets
        outputs = []
        for gene in target_config["gene_params"]["gene_names"]:
            gene_dir = os.path.join("results", gene, run_name)
            outputs.append(os.path.join(gene_dir, "sequence_references.csv"))
            outputs.append(os.path.join(gene_dir, "gene_fetch.log"))
        return outputs

    # Handle organelle/ribosomal targets.
    # The set of per-sample files depends only on the config, so decide it once.
    if target_config["fetch_params"].get("getorganelle", False):
        relative_targets = [
            ("seed.fasta",),
            ("gene.fasta",),
            ("annotated_regions", "get_annotated_regions_from_gb.log"),
        ]
    else:
        relative_targets = [("fasta",), ("genbank",)]

    outputs = []
    # Iterate the two columns directly rather than DataFrame.iterrows():
    # iterrows() builds one Series per row and coerces it to a common dtype,
    # so an integer ID/taxid next to any float column would be rendered as
    # "1.0"/"9606.0" in the path.
    for sample_id, taxid in zip(sample_data["ID"], sample_data["taxid"]):
        target_dir = os.path.join("results", target_type, run_name, str(sample_id), str(taxid))
        outputs.extend(os.path.join(target_dir, *parts) for parts in relative_targets)
    return outputs

# Aggregate target rule: its inputs are whatever get_final_outputs returns for
# the configured target type, so requesting this rule requests the whole run.
rule all:
    input:
        get_final_outputs

# Fetch reference sequences for a single gene by running
# workflow/scripts/gene_fetch.py; the {gene} wildcard is limited to the names
# listed in config target_config.gene_params.gene_names.
rule fetch_gene_sequences:
    output:
        # NOTE(review): the shell block never writes these two paths itself;
        # presumably gene_fetch.py creates both inside output_dir - confirm.
        refs = os.path.join("results", "{gene}", config["run_name"], "sequence_references.csv"),
        log = os.path.join("results", "{gene}", config["run_name"], "gene_fetch.log")
    params:
        output_dir = lambda wildcards: os.path.join("results", wildcards.gene, config["run_name"]),
        protein_size = config["target_config"]["gene_params"]["protein_size"],
        nucleotide_size = config["target_config"]["gene_params"]["nucleotide_size"],
        # NOTE(review): email and api_key are declared here but not referenced
        # in the shell command below - confirm whether gene_fetch.py needs them.
        email = config["email"],
        api_key = config["api_key"],
        # Only the first entry of sequence_types is passed to the script.
        sequence_type = config["target_config"]["gene_params"]["sequence_types"][0]
    log:
        # Snakemake-level log (stderr of the script plus the status lines
        # echoed by the shell block); distinct from output.log above.
        os.path.join("logs", f"gene_fetch-{config['run_name']}-{{gene}}.log")    
    conda:
        "envs/fetch.yaml"
    wildcard_constraints:
        # Gene names are joined into a regex alternation as-is (no escaping).
        gene = "|".join(config["target_config"]["gene_params"]["gene_names"])
    shell:
        """
        python ./workflow/scripts/gene_fetch.py \
            {wildcards.gene} \
            {params.output_dir} \
            {config[samples_csv]} \
            --type {params.sequence_type} \
            --protein_size {params.protein_size} \
            --nucleotide_size {params.nucleotide_size} \
            2> {log} || {{ echo "Error: gene_fetch.py failed" >> {log}; exit 1; }}
            
        # Log success
        echo "gene_fetch successfully completed for gene {wildcards.gene}" >> {log}
        """

#PREVIOUS VERSION OF RULE
#rule fetch_organelle_sequences:
#    output:
#        seed = os.path.join("results", config["target_config"]["target_type"], config["run_name"], "{ID}",
#               "seed.fasta" if config["target_config"]["fetch_params"]["getorganelle"] else "fasta"),
#        gene = os.path.join("results", config["target_config"]["target_type"], config["run_name"], "{ID}",
#               "gene.fasta" if config["target_config"]["fetch_params"]["getorganelle"] else "genbank")
#    params:
#        output_dir = lambda wildcards: os.path.join("results", config["target_config"]["target_type"], config["run_name"], "{ID}"),
#        target_type = config["target_config"]["target_type"],
#        min_seq = config["target_config"]["fetch_params"]["min_sequences"],
#        max_seq = config["target_config"]["fetch_params"]["max_sequences"],
#        email = config["email"],
#        api_key = config["api_key"],
#        getorganelle = "--getorganelle" if config["target_config"]["fetch_params"].get("getorganelle", False) else "",
#        database = config["target_config"]["fetch_params"]["database"],
#        taxid = lambda wildcards: get_taxid(wildcards)
#    log:
#        os.path.join("logs", f"go_fetch-{config['run_name']}-{config['target_config']['target_type']}-{{ID}}.log")
#    conda:
#        "envs/fetch.yaml"
#    shell:
#        """
#        python ./workflow/scripts/go_fetch.py \
#            --taxonomy {params.taxid} \
#            --target {params.target_type} \
#            --db {params.database} \
#            --min {params.min_seq} \
#            --max {params.max_seq} \
#            --output {params.output_dir} \
#            --email {params.email} \
#            --api {params.api_key} \
#            {params.getorganelle} \
#            --overwrite \
#            2> {log}
#        """



# Fetch reference sequences for one sample/taxid pair with go_fetch2.py.
# The job cleans and creates a per-sample temp directory, runs the script with
# that directory as its output, moves everything into the final results
# directory, verifies the expected outputs and removes the temp directory.

# Shared, parse-time values for the rule below.
_FETCH_PARAMS = config["target_config"]["fetch_params"]
# A missing "getorganelle" key means "off", the same default get_final_outputs
# uses; indexing the key directly would raise KeyError for a config that
# validate_config accepts.
_USE_GETORGANELLE = bool(_FETCH_PARAMS.get("getorganelle", False))
_ORGANELLE_SAMPLE_DIR = os.path.join("results",
                                     config["target_config"]["target_type"],
                                     config["run_name"],
                                     "{ID}", "{taxid}")

rule fetch_organelle_sequences:
    output:
        # getorganelle mode produces two FASTA files; otherwise go_fetch2.py
        # produces "fasta" and "genbank" directories (the shell block checks
        # them with -d), which Snakemake requires to be flagged directory().
        seed = (os.path.join(_ORGANELLE_SAMPLE_DIR, "seed.fasta")
                if _USE_GETORGANELLE
                else directory(os.path.join(_ORGANELLE_SAMPLE_DIR, "fasta"))),
        gene = (os.path.join(_ORGANELLE_SAMPLE_DIR, "gene.fasta")
                if _USE_GETORGANELLE
                else directory(os.path.join(_ORGANELLE_SAMPLE_DIR, "genbank"))),
        # Only declared in getorganelle mode; an empty list adds no output.
        annot_log = (os.path.join(_ORGANELLE_SAMPLE_DIR,
                                  "annotated_regions",
                                  "get_annotated_regions_from_gb.log")
                     if _USE_GETORGANELLE else []),
        # Per-sample copy of the run log, kept next to the results.
        sample_log = os.path.join(_ORGANELLE_SAMPLE_DIR, "go_fetch.log")
    params:
        temp_dir = lambda wildcards: f"temp/{wildcards.ID}_{wildcards.taxid}",
        final_dir = lambda wildcards: os.path.join("results",
                                                  config["target_config"]["target_type"],
                                                  config["run_name"],
                                                  wildcards.ID,
                                                  wildcards.taxid),
        target_type = config["target_config"]["target_type"],
        min_seq = _FETCH_PARAMS["min_sequences"],
        max_seq = _FETCH_PARAMS["max_sequences"],
        email = config["email"],
        # NOTE(review): the key is passed on the command line, so it appears in
        # any place the rendered shell command is printed - confirm acceptable.
        api_key = config["api_key"],
        # Empty string when getorganelle mode is off; the shell block also uses
        # this value to pick which outputs to verify.
        getorganelle = "--getorganelle" if _USE_GETORGANELLE else "",
        database = _FETCH_PARAMS["database"],
        taxid = lambda wildcards: wildcards.taxid
    log:
        os.path.join("logs", f"go_fetch-{config['run_name']}-{config['target_config']['target_type']}-{{ID}}-{{taxid}}.log")
    conda:
        "envs/fetch.yaml"
    # Whole-job retries by Snakemake, on top of the in-script retry loops.
    retries: 3
    resources:
        api_requests=1
    shell:
        """
        # Error handling function: log to both logs, clean up, fail the job
        handle_error() {{
            error_message="$1"
            echo "Error: ${{error_message}}" >> {log}
            echo "Error: ${{error_message}}" >> {output.sample_log}
            echo "This job will be retried up to 3 times" >> {log}
            echo "This job will be retried up to 3 times" >> {output.sample_log}
            rm -rf {params.temp_dir}  # Cleanup on error
            exit 1
        }}

        # Function to wait between retries
        wait_and_retry() {{
            sleep_time=$(( ( RANDOM % 30 )  + 30 ))  # Random sleep of 30-59 seconds
            echo "Waiting ${{sleep_time}} seconds before retrying..." >> {log}
            echo "Waiting ${{sleep_time}} seconds before retrying..." >> {output.sample_log}
            sleep ${{sleep_time}}
        }}

        # Ensure clean start with retry
        for i in $(seq 1 3); do
            if rm -rf {params.temp_dir} 2>/dev/null; then
                break
            elif [ $i -eq 3 ]; then
                handle_error "Failed to clean temporary directory after 3 attempts"
            else
                wait_and_retry
            fi
        done

        # Create temporary directory with retry
        for i in $(seq 1 3); do
            if mkdir -p {params.temp_dir}; then
                break
            elif [ $i -eq 3 ]; then
                handle_error "Failed to create temporary directory after 3 attempts"
            else
                wait_and_retry
            fi
        done

        # Initialise the sample log file (truncates any earlier content)
        echo "Starting go_fetch for ID {wildcards.ID} (taxid: {wildcards.taxid})" > {output.sample_log}

        # Run go_fetch with retry
        for i in $(seq 1 3); do
            if python ./workflow/scripts/go_fetch2.py \
                --taxonomy {params.taxid} \
                --target {params.target_type} \
                --db {params.database} \
                --min {params.min_seq} \
                --max {params.max_seq} \
                --output {params.temp_dir} \
                --email {params.email} \
                --api {params.api_key} \
                {params.getorganelle} \
                --overwrite \
                2> >(tee -a {log} {output.sample_log} >&2) \
                > >(tee -a {output.sample_log}); then
                # Check the outputs expected for the active mode exist before
                # declaring success: two FASTA files in getorganelle mode,
                # the fasta and genbank directories otherwise.
                if [ "{params.getorganelle}" != "" ]; then
                    if [ -f "{params.temp_dir}/seed.fasta" ] && [ -f "{params.temp_dir}/gene.fasta" ]; then
                        break
                    fi
                elif [ -d "{params.temp_dir}/fasta" ] && [ -d "{params.temp_dir}/genbank" ]; then
                    break
                fi
                # Reached only when the expected outputs are missing; go through
                # handle_error so both logs are written and temp is cleaned.
                handle_error "Required output files not found despite successful exit"
            else
                if [ $i -eq 3 ]; then
                    handle_error "go_fetch2.py failed after 3 attempts"
                else
                    echo "Attempt $i failed, retrying..." >> {log}
                    wait_and_retry
                fi
            fi
        done

        # Verify go_fetch output exists
        if [ ! -d "{params.temp_dir}" ]; then
            handle_error "go_fetch did not create output directory"
        fi

        # Create final directory with error handling
        mkdir -p {params.final_dir} || handle_error "Failed to create final directory"

        # Move files with error handling
        if [ -d "{params.temp_dir}" ] && [ "$(ls -A {params.temp_dir})" ]; then
            mv {params.temp_dir}/* {params.final_dir}/ || handle_error "Failed to move files to final location"
        else
            handle_error "No output files found in temporary directory"
        fi

        # Verify required files exist in final location
        if [ "{params.getorganelle}" != "" ]; then
            [ -f "{params.final_dir}/seed.fasta" ] || handle_error "seed.fasta not found in output"
            [ -f "{params.final_dir}/gene.fasta" ] || handle_error "gene.fasta not found in output"
            [ -f "{params.final_dir}/annotated_regions/get_annotated_regions_from_gb.log" ] || handle_error "annotated_regions log not found in output"
        else
            [ -d "{params.final_dir}/fasta" ] || handle_error "fasta directory not found in output"
            [ -d "{params.final_dir}/genbank" ] || handle_error "genbank directory not found in output"
        fi

        # Clean up temporary directory
        rm -rf {params.temp_dir} || handle_error "Failed to clean up temporary directory"

        # Log success to both log files
        echo "go_fetch successfully completed for ID {wildcards.ID}. (taxid: {wildcards.taxid})" >> {log}
        echo "go_fetch successfully completed for ID {wildcards.ID}. (taxid: {wildcards.taxid})" >> {output.sample_log}
        """


# Pre-flight checks run by the onstart handler: create the working
# directories, confirm the sample sheet is readable and that a python
# executable is on PATH. Any failed check exits non-zero.
# NOTE(review): this handler has no conda directive, so the python check
# presumably inspects the environment Snakemake was launched from rather than
# a rule's conda environment - confirm the error text is accurate.
onstart:
    shell("""
        # Validate required directories exist
        mkdir -p logs results temp || {{ echo "Failed to create required directories"; exit 1; }}
        
        # Validate samples CSV exists and is readable
        if [ ! -r "{config[samples_csv]}" ]; then
            echo "Error: Cannot read samples CSV file: {config[samples_csv]}"
            exit 1
        fi
        
        # Validate conda environment
        if ! command -v python >/dev/null 2>&1; then
            echo "Error: Python not found in conda environment"
            exit 1
        fi
        """
    )

# On success, remove the shared temp/ directory that onstart created and the
# fetch jobs used for their per-sample working directories.
onsuccess:
    shell("rm -rf temp/")

# On failure, print a notice and remove temp/ as well, so no partial
# per-sample working directories are left behind.
onerror:
    shell("""
        echo "Workflow failed. Cleaning up temporary files..."
        rm -rf temp/
        """
    )
