Why do I get ConnectionResetError when reading from and writing to S3 using smart_open?


The following code reads from and writes back to S3 on the fly, following the discussion here:

from smart_open import open
import os

bucket_dir = "s3://my-bucket/annotations/"

with open(os.path.join(bucket_dir, "in.tsv.gz"), "rb") as fin:
    with open(
        os.path.join(bucket_dir, "out.tsv.gz"), "wb"
    ) as fout:
        for line in fin:
            l = [i.strip() for i in line.decode().split("\t")]
            string = "\t".join(l) + "\n"
            fout.write(string.encode())    

The issue is that after a few thousand lines are processed (a few minutes), I get a "connection reset by peer" error:

    raise ProtocolError("Connection broken: %r" % e, e)
urllib3.exceptions.ProtocolError: ("Connection broken: ConnectionResetError(104, 'Connection reset by peer')", ConnectionResetError(104, 'Connection reset by peer'))

What can I do? I tried calling fout.flush() after every fout.write(string.encode()), but it doesn't help. Is there a better approach for processing a .tsv file with about 200 million lines?

1 Answer

Answered by 0x90 (accepted):

I implemented a producer-consumer approach on top of smart_open. It mitigates the "Connection broken" error, but in some cases doesn't resolve it completely.

import os
import threading
from queue import Queue

from smart_open import open
from tqdm import tqdm


class Producer:
    def __init__(self, queue, bucket_dir, input_file):
        self.queue = queue
        self.bucket_dir = bucket_dir
        self.input_file = input_file

    def run(self):
        with open(os.path.join(self.bucket_dir, self.input_file), "rb") as fin:
            for line in tqdm(fin):
                self.queue.put(line)  # blocks while the queue is full
        self.queue.put("DONE")  # str sentinel; data lines are bytes, so no clash


class Consumer:
    def __init__(self, queue, bucket_dir, output_file):
        self.queue = queue
        self.bucket_dir = bucket_dir
        self.output_file = output_file

    def run(self):
        to_write = b""  # batch buffer; incoming lines are bytes
        count = 0
        with open(os.path.join(self.bucket_dir, self.output_file), "wb") as fout:
            while True:
                item = self.queue.get()  # blocks until an item is available
                if item == "DONE":
                    fout.write(to_write)  # write out the remaining partial batch
                    fout.flush()
                    self.queue.task_done()
                    return

                count += 1
                to_write += item
                if count % 256 == 0:  # batch writes to cut down on S3 requests
                    fout.write(to_write)
                    fout.flush()
                    to_write = b""  # reset the buffer after each batch
                self.queue.task_done()  # so q.join() in main() can complete


def main(args):
    q = Queue(1024)

    producer = Producer(q, args.bucket_dir, args.input_file)
    producer_thread = threading.Thread(target=producer.run)

    consumer = Consumer(q, args.bucket_dir, args.output_file)
    consumer_thread = threading.Thread(target=consumer.run)

    producer_thread.start()
    consumer_thread.start()

    producer_thread.join()
    consumer_thread.join()
    q.join()
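
To run this end to end, I wire main up with argparse. The flag names below are just my choice; anything that gives args the bucket_dir, input_file, and output_file attributes that main reads will do:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Stream a .tsv.gz through S3 with a producer/consumer pair"
    )
    parser.add_argument("--bucket_dir", default="s3://my-bucket/annotations/")
    parser.add_argument("--input_file", default="in.tsv.gz")
    parser.add_argument("--output_file", default="out.tsv.gz")
    main(parser.parse_args())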
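
Independently of the threading, it can also help to make the underlying S3 client itself more tolerant of resets. A minimal sketch, assuming smart_open 5.x (where the S3 transport accepts a boto3 client via transport_params) and botocore's retry configuration:

import boto3
from botocore.config import Config
from smart_open import open

# Client with more aggressive retries; "adaptive" mode backs off and
# retries automatically on transient errors such as connection resets.
client = boto3.client(
    "s3", config=Config(retries={"max_attempts": 10, "mode": "adaptive"})
)

with open(
    "s3://my-bucket/annotations/in.tsv.gz",
    "rb",
    transport_params={"client": client},
) as fin:
    for line in fin:
        pass  # process each line as before

I haven't verified that this removes the error entirely, but retrying at the transport layer attacks the same symptom from the other side.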