diff --git a/src/extract_rdcosts.py b/src/extract_rdcosts.py
index fbcb8479..33938abf 100755
--- a/src/extract_rdcosts.py
+++ b/src/extract_rdcosts.py
@@ -123,7 +123,6 @@ def run_job(job):
 
     with open(logpath, "w") as lf:
         with MultiPipeGZOutManager(odpath, dest_qps) as pipes_and_outputs:
-            gzips = []
            gzip_threads = []
             for pipe_fn, out_fn in pipes_and_outputs.items():
                 gzip_thread = threading.Thread(target=do_gzip, args=(pipe_fn, out_fn))
@@ -131,10 +130,7 @@ def run_job(job):
                 gzip_threads.append(gzip_thread)
 
             kvz = subprocess.Popen(my_kvzargs, env=kvzenv, stderr=lf)
-
-            kvz.communicate()
-            for gzip in gzips:
-                gzip.communicate()
+            kvz.wait()
 
 def threadfunc(joblist):
     for job in joblist:
diff --git a/src/filter_rdcosts.c b/src/filter_rdcosts.c
index 0723dcef..ef260a1c 100644
--- a/src/filter_rdcosts.c
+++ b/src/filter_rdcosts.c
@@ -121,6 +121,7 @@ int process_rdcosts(FILE *in, FILE *out)
     }
     printf("\n");
   }
+  fflush(stdout);
 
 out:
   free(buf);
diff --git a/src/run_filter.py b/src/run_filter.py
index 784533ee..b685c218 100755
--- a/src/run_filter.py
+++ b/src/run_filter.py
@@ -1,20 +1,56 @@
 #!/usr/bin/env python3
 
 import glob
+import gzip
 import os
+import re
 import subprocess
+import sys
 import tempfile
 import threading
 import time
 
 n_threads = 8
-data = "/home/moptim/rdcost/data/*.gz"
+datadirs = "/tmp/rdcost/data/"
+#datadirs = "/tmp/rdcost/data/RaceHorses_416x240_30.yuv-qp23/"
 gzargs = ["gzip", "-d"]
 filtargs = ["./frcosts_matrix"]
 octargs = ["octave-cli", "invert_matrix.m"]
 filt2args = ["./ols_2ndpart"]
 resultdir = os.path.join("/tmp", "rdcost", "coeff_buckets")
 
+gz_glob = "[0-9][0-9].txt.gz"
+
+class MultiPipeManager:
+    pipe_fn_template = "%02i.txt"
+
+    def __init__(self, odpath, dest_qps):
+        self.odpath = odpath
+        self.dest_qps = dest_qps
+
+        self.pipe_fns = []
+        for qp in dest_qps:
+            pipe_fn = os.path.join(self.odpath, self.pipe_fn_template % qp)
+            self.pipe_fns.append(pipe_fn)
+
+    def __enter__(self):
+        os.makedirs(self.odpath, exist_ok=True)
+        for pipe_fn in self.pipe_fns:
+            try:
+                os.unlink(pipe_fn)
+            except FileNotFoundError:
+                pass
+            os.mkfifo(pipe_fn)
+        return self
+
+    def __exit__(self, *_):
+        for pipe_fn in self.pipe_fns:
+            os.unlink(pipe_fn)
+
+    def items(self):
+        for pipe_fn in self.pipe_fns:
+            yield pipe_fn
+
 class MTSafeIterable:
     def __init__(self, iterable):
         self.lock = threading.Lock()
@@ -27,40 +63,82 @@ class MTSafeIterable:
         with self.lock:
             return next(self.iterable)
 
-def run_job(job):
-    datapath = job
-    resultpath = os.path.join(resultdir, os.path.basename(job) + ".result")
+def read_in_blocks(f):
+    BLOCK_SZ = 65536
+    while True:
+        block = f.read(BLOCK_SZ)
+        if (len(block) == 0):
+            break
+        else:
+            yield block
 
-    print("Running job %s" % datapath)
+def exhaust_gzs(sink_f, gzs):
+    for gz in gzs:
+        with gzip.open(gz, "rb") as f:
+            # NOTE(review): a leftover debug trap that printed "kjeh" for one
+            # hard-coded input path was removed here; it had no useful effect.
+            print(" Doing %s ..." % gz)
+            for block in read_in_blocks(f):
+                sink_f.write(block)
+                sink_f.flush()
 
-    with open(resultpath, "w") as rf:
-        with tempfile.NamedTemporaryFile() as tf:
-            with open(datapath, "rb") as df:
-                f2a = list(filt2args)
-                f2a.append(tf.name)
-                gzip = subprocess.Popen(gzargs, stdin=df, stdout=subprocess.PIPE)
-                filt = subprocess.Popen(filtargs, stdin=gzip.stdout, stdout=subprocess.PIPE)
-                octa = subprocess.Popen(octargs, stdin=filt.stdout, stdout=tf)
+def run_job(jobname, input_gzs):
+    resultpath = os.path.join(resultdir, "%s.result" % jobname)
+    print("Running job %s" % jobname)
 
-                octa.communicate()
-                filt.communicate()
-                gzip.communicate()
+    with tempfile.NamedTemporaryFile() as tf:
+        filt = subprocess.Popen(filtargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+        octa = subprocess.Popen(octargs, stdin=filt.stdout, stdout=tf)
 
-            with open(datapath, "rb") as df:
-                gz2 = subprocess.Popen(gzargs, stdin=df, stdout=subprocess.PIPE)
-                f2 = subprocess.Popen(f2a, stdin=gz2.stdout, stdout=rf)
+        try:
+            exhaust_gzs(filt.stdin, input_gzs)
+        except OSError as e:
+            print("OSError %s" % e, file=sys.stderr)
+            raise
 
-            f2.communicate()
-            gz2.communicate()
+        filt.stdin.close()
+        filt.wait()
+        octa.wait()
 
-    print("Job %s done" % datapath)
+        if (filt.returncode != 0):
+            print("First stage failed: %s" % jobname, file=sys.stderr)
+            assert(0)
+
+        with open(resultpath, "w") as rf:
+            f2a = filt2args + [tf.name]
+            f2 = subprocess.Popen(f2a, stdin=subprocess.PIPE, stdout=rf)
+            exhaust_gzs(f2.stdin, input_gzs)
+            f2.communicate()
+            if (f2.returncode != 0):
+                print("Second stage failed: %s" % jobname, file=sys.stderr)
+                assert(0)
+
+    print("Job %s done" % jobname)
 
 def threadfunc(joblist):
-    for job in joblist:
-        run_job(job)
+    for jobname, job in joblist:
+        run_job(jobname, job)
+
+def scan_datadirs(path):
+    seq_names = set()
+    for dirent in os.scandir(path):
+        if (not dirent.is_dir()):
+            continue
+        match = re.search(r"^([A-Za-z0-9_]+\.yuv)-qp[0-9]{1,2}$", dirent.name)
+        if (not match is None):
+            seq_name = match.groups()[0]
+            seq_names.add(seq_name)
+
+    for seq_name in seq_names:
+        seq_glob = os.path.join(path, seq_name + "-qp*/")
+
+        for qp in range(51):
+            job_name = seq_name + "-qp%02i" % qp
+            qp_fn = "%02i.txt.gz" % qp
+            yield job_name, glob.glob(os.path.join(seq_glob, qp_fn))
 
 def main():
-    jobs = glob.glob(data)
+    jobs = scan_datadirs(datadirs)
     joblist = MTSafeIterable(iter(jobs))
     threads = [threading.Thread(target=threadfunc, args=(joblist,)) for _ in range(n_threads)]