I've run into a problem where PyTorch (tested with 2.0.1+cu117) does not fail gracefully when a CPU (RAM) OOM occurs. Specifically, I lose all SSH connections and X server access to the VM or bare-metal machine.
I haven't tested whether this occurs on any other OS.
The only solution I've found is to reboot the machine outright (through vSphere or just the power button).
I've considered aliasing python (accounting for conda env switching) so that each process gets added to a cgroup that directly limits memory usage, but I've been advised that messing with cgroups is broadly a bad idea.
It is hard to judge how much memory a model is going to take, and I need a graceful way to kill the training process without it taking down my SSH server.
Some notes: I've tried this on two machines, one VM and one bare metal; both run Ubuntu 22.04, and I believe both have the OS-level OOM killer enabled. The problem is not reproducible with GPU OOM, which actually kills the process and returns the typical GPU OOM error; it only happens with RAM.
I've tried setting an RLIMIT as described here: https://www.geeksforgeeks.org/python-how-to-put-limits-on-memory-and-cpu-usage/, though that hasn't solved my problem.
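For context, the approach from that article boils down to something like the following (a minimal sketch; limit_memory and the 32 GiB figure are placeholders of my own, not from the article):

```python
import resource

def limit_memory(max_bytes):
    # Cap this process's virtual address space. Allocations beyond the cap
    # fail with a Python MemoryError instead of invoking the system OOM killer.
    soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    resource.setrlimit(resource.RLIMIT_AS, (max_bytes, hard))

limit_memory(32 * 1024 ** 3)  # e.g. 32 GiB; placeholder value, tune per machine
```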
I've also considered just adding a condition in my training loop that checks available memory and breaks, but this seems like an unclean solution.
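What I mean is something along these lines (just a sketch; ram_is_low and the 4 GiB threshold are arbitrary names/values of my own):

```python
import psutil

def ram_is_low(min_free_gib=4.0):
    # True once available system RAM drops below the threshold.
    return psutil.virtual_memory().available < min_free_gib * 1024 ** 3

# inside the training loop:
#     if ram_is_low():
#         print("Available RAM below threshold, stopping this run")
#         break
```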
I haven't tried cgroups for the reasons mentioned above. I'm also only tangentially familiar with OS-level tooling and hesitate to touch it lest I break something I cannot fix.
Update: https://www.geeksforgeeks.org/python-how-to-put-limits-on-memory-and-cpu-usage/
^ I had implemented this incorrectly before (a misspelling in the rlimit attribute name), but it does in fact solve the problem and prevents overuse of RAM.
https://github.com/daniel-redder/mem_restrict/blob/main/mem_restrict.py
^ import this to constrain RAM usage
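For anyone who doesn't want to follow the link, the idea is roughly the following (a sketch of the same technique, not the exact contents of that file; the 0.8 fraction is a placeholder):

```python
# mem_restrict.py (sketch): cap this process's address space to a fraction of
# the RAM currently available, so a runaway allocation raises MemoryError
# instead of freezing the whole machine.
import resource
import psutil

FRACTION = 0.8  # placeholder; tune to taste

available = psutil.virtual_memory().available
_, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (int(available * FRACTION), hard))
```

With a module like this, a single import at the top of the training script applies the cap before any tensors are allocated.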
(I will keep discussing the OOM-killer behavior here because it is interesting, but I wanted to post the easy solution first.)