import sys
import random
import argparse
import gzip

def _open_text(path, mode):
    """Open *path* for text I/O ('r' or 'w'), transparently using gzip for '.gz' files."""
    if path.endswith('.gz'):
        return gzip.open(path, mode + 't')
    return open(path, mode)


def reservoir_sample_fastq(input_files, output_files, target_reads):
    """
    Subsample extremely large FASTQ files keeping Read 1 and Read 2 pairs accurately synced.

    Uses standard Reservoir Sampling (Algorithm R) in a single pass, so arbitrarily
    large files can be processed with memory bounded by ``target_reads`` records.

    Parameters
    ----------
    input_files : list[str]
        1 (single-end) or 2 (paired-end) FASTQ paths; '.gz' suffix enables gzip.
    output_files : list[str]
        Destination paths, one per input; '.gz' suffix enables gzip.
    target_reads : int
        Number of records (pairs, for paired input) to retain.

    Notes
    -----
    - A truncated final record (file length not a multiple of 4 lines) is dropped
      with a warning rather than emitted as a malformed record.
    - If paired inputs have different read counts, sampling stops at the shorter
      file and a warning is printed to stderr (pairs stay synced up to that point).
    """
    k = target_reads

    in_handles = [_open_text(f, 'r') for f in input_files]
    out_handles = [_open_text(f, 'w') for f in output_files]

    # One reservoir per file; replacements use the same index j in every
    # reservoir, which is what keeps R1/R2 pairs synced.
    reservoir = [[] for _ in input_files]

    try:
        current_read = 0
        while True:
            # Read one 4-line FASTQ record from each file.
            record_blocks = []
            truncated = False
            for handle in in_handles:
                lines = [handle.readline() for _ in range(4)]
                if not lines[0]:
                    break  # clean EOF on this file
                if not lines[3]:
                    truncated = True  # file ended mid-record
                    break
                record_blocks.append(''.join(lines))

            if len(record_blocks) < len(in_handles):
                if truncated:
                    print("Warning: input ended mid-record; dropping incomplete final record.",
                          file=sys.stderr)
                else:
                    # Detect paired files of unequal length: either a later file
                    # hit EOF while an earlier one yielded a record, or the
                    # EOF'd file's siblings still have unread data.
                    leftovers = any(h.readline() for h in in_handles)
                    if record_blocks or leftovers:
                        print("Warning: paired input files have different read counts; "
                              "extra reads in the longer file were ignored.",
                              file=sys.stderr)
                break

            if current_read < k:
                # Reservoir not yet full: keep everything.
                for i, block in enumerate(record_blocks):
                    reservoir[i].append(block)
            else:
                # Item index is current_read (0-based); randint is inclusive, so
                # P(j < k) = k / (current_read + 1) — standard Algorithm R.
                j = random.randint(0, current_read)
                if j < k:
                    for i, block in enumerate(record_blocks):
                        reservoir[i][j] = block

            current_read += 1
            if current_read % 100000 == 0:
                print(f"Scanned {current_read} reads...")

        # Write the reservoir to disk.
        print(f"\nCompleted traversal! Found {current_read} total reads. Writing output files...")
        for i in range(len(output_files)):
            out_handles[i].writelines(reservoir[i])

        print(f"Success! Sampled {min(k, current_read)} sequences.")

    finally:
        for h in in_handles:
            h.close()
        for h in out_handles:
            h.close()

if __name__ == '__main__':
    # Command-line entry point: parse arguments, validate, and run the sampler.
    parser = argparse.ArgumentParser(description="High Performance FASTQ Paired/Single Downsampler")
    parser.add_argument('-i', '--input', nargs='+', required=True, help="Input FASTQ files (1 for single, 2 for exactly matched pairs).")
    parser.add_argument('-o', '--output', nargs='+', required=True, help="Output destination for sampled FASTQ files.")
    parser.add_argument('-n', '--num_reads', type=int, required=True, help="Exact target number of sequence reads to retain.")

    args = parser.parse_args()

    if len(args.input) != len(args.output):
        sys.exit("Error: Mismatched number of inputs and output files provided.")
    # The sampler only supports single-end (1 file) or paired-end (2 files),
    # as promised by the --input help text; reject anything else up front.
    if len(args.input) > 2:
        sys.exit("Error: Provide 1 input file (single-end) or 2 input files (paired-end).")
    # A non-positive target would silently produce empty outputs with a
    # "Success!" message; fail fast instead.
    if args.num_reads <= 0:
        sys.exit("Error: --num_reads must be a positive integer.")

    print(f"Initializing Reservoir Engine... Target: {args.num_reads} reads per file.")
    reservoir_sample_fastq(args.input, args.output, args.num_reads)
