PyTorch CPU OOM kills SSH server on Linux


I've run into a problem where PyTorch (tested with 2.0.1+cu117) does not fail gracefully when a CPU (RAM) OOM occurs. Specifically, I lose all SSH connections and X server access to the VM or bare-metal machine.

I haven't tested whether this occurs on any other OS.

The only solution I've found is to reboot the machine outright (through vSphere or just the power button).

I've considered aliasing python (accounting for conda env switching) to add each process to a cgroup that directly limits memory usage, but I've been advised that messing with cgroups is generally a bad idea.
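(For reference, a transient systemd scope would apply such a cap without hand-editing the cgroup hierarchy; a sketch, assuming a systemd-based distro with cgroup v2, with 16G as an example cap:)

systemd-run --user --scope -p MemoryMax=16G python yourscript.py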

It is hard to judge in advance how much memory a model will take, and I need a graceful way to kill the process without it taking down my SSH server.

Some notes: I've tried this on two machines, one VM and one bare metal; both run Ubuntu 22.04, and I believe both have the OS-level OOM killer enabled. This is not reproducible with GPU OOM, which actually kills the process and returns a typical GPU OOM error; it happens only with RAM.

I've tried setting the RLIMIT as described here: https://www.geeksforgeeks.org/python-how-to-put-limits-on-memory-and-cpu-usage/, but that hasn't solved my problem.

I've also considered just adding a condition in my training loop that checks available memory and breaks, but this seems like an unclean solution.
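A rough sketch of what I mean (num_epochs and train_one_epoch are placeholders for my actual loop; psutil assumed installed):

import psutil

MIN_AVAILABLE = 2 * 1024**3  # stop if less than ~2 GiB of RAM remains (arbitrary threshold)

for epoch in range(num_epochs):  # placeholder loop bound
    if psutil.virtual_memory().available < MIN_AVAILABLE:
        print("Available memory too low, stopping training early")
        break
    train_one_epoch()  # placeholder for the real training step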

I haven't tried cgroups, for the reasons mentioned above. I'm also only tangentially familiar with OS internals, and I hesitate to touch anything lest I break something I cannot fix.


There are 2 answers

Daniel Redder (BEST ANSWER)

I had implemented the approach from https://www.geeksforgeeks.org/python-how-to-put-limits-on-memory-and-cpu-usage/ incorrectly before (a misspelled attribute in the resource call); done correctly, it does in fact solve the problem and prevents overuse of RAM.

I've put it in a module you can import to apply the constraint: https://github.com/daniel-redder/mem_restrict/blob/main/mem_restrict.py

import psutil
import resource

# Adapted from https://www.geeksforgeeks.org/python-how-to-put-limits-on-memory-and-cpu-usage/

PERCENTAGE_MEMORY_ALLOWED = 0.8

# Calculate the maximum memory limit (80% of currently available memory)
virtual_memory = psutil.virtual_memory()
available_memory = virtual_memory.available
memory_limit = int(available_memory * PERCENTAGE_MEMORY_ALLOWED)

print(f'memory limit: {memory_limit}, available: {available_memory}')

# Cap this process's address space: allocations beyond the cap fail with
# MemoryError instead of triggering the system OOM killer
resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))
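With the limit in place, an allocation that exceeds the cap raises MemoryError (or a RuntimeError from torch, depending on where the allocation happens) instead of invoking the system OOM killer, so you can catch it and shut down cleanly. A minimal sketch, with train() standing in for your actual training entry point:

try:
    train()  # placeholder for your training entry point
except MemoryError:
    # The RLIMIT_AS cap was hit; exit cleanly instead of taking down sshd
    print("Out of memory under the rlimit cap, stopping training")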

(I will continue discussing the OOM killer stuff here because it is interesting, but I wanted to get the easy solution up first.)

maxy

You can try tuning oom_score_adj using the choom command:

choom -n 1000 -- python yourscript.py

Quoting man proc about /proc/<pid>/oom_score_adj:

[..] adjust the badness heuristic used to select which process gets killed in out-of-memory conditions. [The badness value ranges] from 0 (never kill) to 1000 (always kill) to determine which process is targeted.
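If you'd rather not depend on the choom binary (it ships with util-linux), the same adjustment can be made from inside the script by writing to /proc/self/oom_score_adj; a minimal sketch (Linux only; raising the value requires no special privileges):

# Make this process the OOM killer's preferred victim
with open("/proc/self/oom_score_adj", "w") as f:
    f.write("1000")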

If you are not sure if the OOM killer was active, you can check what happened in the kernel log with journalctl -ke. (If you had to reboot, add --boot -1.) It will look like this:

Out of memory: Killed process 68412 (python3) total-vm:70398288kB, anon-rss:61052944kB, file-rss:2304kB, shmem-rss:0kB, UID:1000 pgtables:119712kB oom_score_adj:1000

Further reading: