How to detect and characterize when a Celery worker dies due to OOM


I'm using Celery to execute some (potentially) memory intensive tasks and I need a way to identify which tasks fail due to out-of-memory (OOM) exceptions. Specifically I'd like to record, in an external database, which tasks have failed for reporting and analysis purposes.

I have not been able to figure out what information is available, both within the task function body and in the on_failure method of the Request object, that could be used to link the parameters passed to the task (arg1, arg2 in the code below) to the failure. The approach I have in mind requires some sort of "key": a piece of information that is accessible both in the task body and in the Request.on_failure method. The process would be:

  1. When the task starts: store a record in a database that includes the arg1, arg2 information for the task; a record that can later be recovered via this "key" and identifies the state of the run.
  2. When the task dies: use the "key" to update the record and indicate that the error occurred (a sketch of this record-keeping follows the list).
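
To make this concrete, here is a minimal sketch of the record-keeping I have in mind, using sqlite3 as a stand-in for the external database; the key parameter is a placeholder for whatever shared identifier turns out to be available in both places:

import sqlite3

_db = sqlite3.connect("task_log.db")
_db.execute(
    "CREATE TABLE IF NOT EXISTS task_log"
    " (key TEXT PRIMARY KEY, arg1 INTEGER, arg2 INTEGER, status TEXT)"
)

def record_start(key, arg1, arg2):
    # called from the task body: remember which parameters this run used
    _db.execute(
        "INSERT OR REPLACE INTO task_log VALUES (?, ?, ?, 'started')",
        (key, arg1, arg2),
    )
    _db.commit()

def record_failure(key):
    # called from Request.on_failure: flag the run as having died
    _db.execute("UPDATE task_log SET status = 'failed' WHERE key = ?", (key,))
    _db.commit()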

Note that I do not need or want to retry the task; I just need to record which task died, for later assessment.

I'm running the Celery workers within a Kubernetes cluster, so they have strict memory constraints.

import logging
import os
import re

import celery
import celery.worker.request
import numpy

logger = logging.getLogger()

# the Celery app instance ("agent" matches its use in the decorator below)
agent = celery.Celery("tasks")

class _CatchingRequest(celery.worker.request.Request):

    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):

        # What I want here is some way to know
        # which task died

        logger.error("In _CatchingRequest.on_failure; pid %s", os.getpid())
        logger.error("exc_info: %s [%s]", exc_info, type(exc_info))
        logger.error("dir(exc_info): %s", dir(exc_info))
        logger.error("traceback: %s %s", type(exc_info.traceback), exc_info.traceback)
        tb = exc_info.traceback
        # the pool reports a SIGKILLed child with a "Job: <n>" number in the traceback
        patt = re.compile(r"SIGKILL.*Job:\s(\d+)\.")
        if m := patt.search(tb):
            job = int(m.group(1))
            logger.error("pattern matches %d", job)
        return super().on_failure(exc_info, send_failed_event, return_ok)
                                                                                
                                                                                
class _CatchingTask(celery.Task):
    """Just a stub so we can actually handle the failure."""
    Request = _CatchingRequest


@agent.task(ignore_result=True, acks_late=True, base=_CatchingTask, bind=True)
def api_create_layer(self, arg1, arg2, timestamp):
    logger.error("enter create layer; job id: %s", self.request.id)

    # What I want here is some kind of value/identifier
    # that can be obtained in the on_failure call

    try:
        _memory_intensive(arg1, arg2, timestamp)
    except Exception as err:
        logger.error("Caught generic memory intensive exception")
        logger.error(str(err))
        # any other record keeping


def _memory_intensive(arg1, arg2, timestamp):
    # just illustrative of something that could easily use a lot of memory
    arr = numpy.zeros([arg1, arg2, arg1, arg2], dtype=numpy.float64)
    numpy.save(f"zeros-{arg1}-{arg2}-{timestamp}.npy", arr)
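
For reference, a call along these lines should exercise the OOM path on a memory-constrained worker; the argument values are arbitrary, but 200**4 float64 values come to roughly 12.8 GB, well over a worker limited to a few GB:

api_create_layer.delay(200, 200, "2024-01-01T00:00:00")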

                          