Run cleanup function in multiple Haskell child threads when POSIX Signal sent (SIGTERM etc)


TL;DR - how do I make the following work in Haskell:

  • Send a SIGTERM to a long-running program with many active threads (each working on a job)
  • Get all child threads to run a cleanup function (updating the database to say the job has aborted) before they exit

To my (very inexperienced) mind, it seems like the cleanest way to make this happen is to trap the SIGTERM in the 'main' thread, raise asynchronous exceptions in the child threads, and then use bracket in the child threads to react to the asynchronous exception by running some cleanup code. Empirically I cannot make this work.


More colour:

I have a Haskell program that spawns a number of threads to do work (using async). Basically, it:

  • Waits on notifications from the database job queue for new jobs
  • Spawns a new thread to do the work in
    • The thread updates the job status in the database as it progresses (e.g. running, paused)
    • If the job completes, the user cancels the job, or a synchronous exception occurs, it updates the database with the final state (e.g. completed, cancelled, aborted)

Crucially, the main thread runs forever, just listening for new jobs, unless interrupted by SIGINT, SIGTERM, SIGKILL etc.

When the program gets a SIGINT or SIGTERM, I want to run some cleanup (namely, updating the database to set the status of in-flight jobs to aborted) before the 'child' thread dies. However, I absolutely cannot figure out how to do this.

My understanding is that the way to handle exceptions thrown from a 'parent' thread to a 'child' thread is to use bracket, which masks out the async exception for the main body of work, allowing you to run a cleanup function prior to terminating.
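
For reference, here is a minimal sketch of how bracket-style cleanup reacts to an asynchronous exception, using only base. As far as the Control.Exception documentation goes, bracket masks async exceptions during the acquire and release actions, while the body runs unmasked, so the exception lands in the body and the release action then runs. The names worker and runDemo are hypothetical, and the two MVars stand in for the database updates.

```haskell
import Control.Concurrent (forkIO, killThread, threadDelay)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar)
import Control.Exception (bracket_)

-- Hypothetical worker: the MVars stand in for "mark job running" and
-- "mark job aborted" database updates.
worker :: MVar () -> MVar String -> IO ()
worker started finished =
  bracket_
    (putMVar started ())              -- acquire: runs with async exceptions masked
    (putMVar finished "cleanup ran")  -- release: runs on normal exit or on any exception
    (threadDelay 10000000)            -- body: async exceptions can arrive here

runDemo :: IO String
runDemo = do
  started <- newEmptyMVar
  finished <- newEmptyMVar
  tid <- forkIO (worker started finished)
  takeMVar started   -- wait until the worker has entered the bracket
  killThread tid     -- throws ThreadKilled to the worker asynchronously
  takeMVar finished  -- blocks until the release action has run

main :: IO ()
main = runDemo >>= putStrLn
```

The parent only needs the child's ThreadId (or its Async) to trigger the cleanup; the child itself decides what the cleanup does.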

However, bracket doesn't seem to interact well with signal handling. I have tried installing signal handlers, to try to get the SIGTERM converted into a runtime exception that I can handle properly. This works great for the thread in which I installed the handler, but I can't throw an asynchronous exception to other threads, I think because they've also received the SIGTERM, and so they just die immediately.

It also appears that I can't install an individual SIGTERM handler per thread, because it looks like the runtime can only have one signal handler per signal type across all threads (basically, if I do this, the last thread to start gets the interrupt, but all other threads, including the main thread, keep running).
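
That observation can be checked in isolation. The sketch below assumes the unix package (System.Posix.Signals) and uses SIGUSR1 so it does not interfere with normal termination; runDemo is a hypothetical name. It installs two handlers for the same signal and reports which one fires. As far as I can tell, installHandler replaces the process-wide handler and returns the previous one, so only the last installation should be active.

```haskell
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import System.Posix.Signals (Handler (Catch), installHandler, raiseSignal, sigUSR1)

runDemo :: IO String
runDemo = do
  fired <- newEmptyMVar
  -- First installation; the result (the previous handler) is ignored here.
  _ <- installHandler sigUSR1 (Catch (putMVar fired "first handler")) Nothing
  -- Second installation for the same signal replaces the first one.
  _ <- installHandler sigUSR1 (Catch (putMVar fired "second handler")) Nothing
  raiseSignal sigUSR1  -- send SIGUSR1 to this process
  takeMVar fired       -- wait for whichever handler runs

main :: IO ()
main = runDemo >>= putStrLn
```

This is why a single handler that forwards the signal to one known thread (usually the main thread) tends to be the workable design.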


Edited to add

Here's some example code that I've developed based on @Li-yao Xia's answer (which was super helpful - thank you).

One piece of the puzzle is that I'm creating child threads inside a recursive function (which listens to notifications on a job queue, then potentially spawns new workers in response). However, I can't see how to pass the list of child threads to the exception handler, unless I attach the handler to every call of the recursive function (see example code below). This means that I'm not able to use tail-call recursion effectively, and if I terminate the program, the exception handler gets called once for every loop as the stack unwinds. Is there a better pattern to make this work?

import Control.Concurrent.Async (Async, async, cancelWith)

import Control.Exception (AsyncException (..), Exception, SomeException, catch, throwIO)
import System.Posix (Signal)
import System.Posix.Signals (Handler (..), installHandler, sigHUP, sigINT, sigTERM, sigUSR1, sigUSR2, sigXCPU, sigXFSZ)

import Control.Concurrent (myThreadId, threadDelay, throwTo)
import Control.Monad (forM_)
import Data.Data (Typeable)

import Data.Foldable (for_)

data Result = Done | Aborted deriving (Show)

termMsg :: Int -> Result -> IO ()
termMsg n s = putStrLn $ "Thread " ++ show n ++ " terminated with " ++ show s

thread :: Int -> IO ()
thread n = job `catch` asyncHandler
  where
    job = do
      for_ ([0 .. 9] :: [Int]) $ \_ -> do
        putStrLn $ "Thread " ++ show n ++ " alive..."
        threadDelay $ 500000 * n -- integer literal: 5e5 is not an Int without NumDecimals
      termMsg n Done

    asyncHandler :: AsyncException -> IO ()
    asyncHandler _ = do
      termMsg n Aborted

parent :: IO ()
parent = run 0 []
  where
    run :: Int -> [Async ()] -> IO ()
    run n w = do
      ( do
          putStrLn $ "Main thread alive (loop " ++ show n ++ ")"
          t <- async (thread n)
          let nw = w ++ [t]
          threadDelay 1000000 -- integer literal: 1e6 is not an Int without NumDecimals
          run (n + 1) nw
        )
        `catch` handler w

    handler :: [Async ()] -> SomeException -> IO ()
    handler children e = do
      print e
      cleanupChildren children
      throwIO e

main :: IO ()
main = do
  installSignalHandlers
  parent `catch` someExceptionHandler

cleanupChildren :: [Async ()] -> IO ()
cleanupChildren children = do
  putStrLn "Cleaning up children..."
  for_ children $ \t -> cancelWith t ThreadKilled

someExceptionHandler :: SomeException -> IO ()
someExceptionHandler e = do
  putStrLn $ "Terminating with " ++ show e
  throwIO e

data SignalException = SignalException Signal String
  deriving (Show, Typeable, Eq)
instance Exception SignalException

signalsToHandle :: [(Signal, String)]
signalsToHandle = [(sigHUP, "SIGHUP"), (sigINT, "SIGINT"), (sigTERM, "SIGTERM"), (sigUSR1, "SIGUSR1"), (sigUSR2, "SIGUSR2"), (sigXCPU, "SIGXCPU"), (sigXFSZ, "SIGXFSZ")]

installSignalHandlers :: IO ()
installSignalHandlers = do
  mainId <- myThreadId
  forM_ signalsToHandle $ \(sig, name) -> installHandler sig (Catch (throwTo mainId $ SignalException sig name)) Nothing

Result:

Main thread alive (loop 0)
Thread 0 alive...

...

Main thread alive (loop 7)
Thread 7 alive...
Thread 3 alive...
Thread 5 alive...
Thread 4 alive...
Thread 2 alive...
Main thread alive (loop 8)
Thread 8 alive...
^CSignalException 2 "SIGINT"
Cleaning up children...
Thread 2 terminated with Aborted
Thread 3 terminated with Aborted
Thread 4 terminated with Aborted
Thread 5 terminated with Aborted
Thread 6 terminated with Aborted
Thread 7 terminated with Aborted
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
Terminating with SignalException 2 "SIGINT"

There are 2 answers

Li-yao Xia (best answer)

Here is a small example.

You may want to post your own minimal example to help diagnose your particular issue.

Could it be that your main thread is terminating without waiting for its children?

import Control.Concurrent (threadDelay, myThreadId)
import Control.Concurrent.Async
import Control.Exception
import Data.Foldable (for_)
import Data.Traversable (for)
import System.Posix.Signals (installHandler, sigTERM, Handler(..))

data Result = Done | Aborted deriving Show

thread :: IO Result
thread = job `catch` handler
  where
    job = do
      threadDelay 5000000
      pure Done
    handler AsyncCancelled = do
      -- additional clean up can be done here
      pure Aborted

main :: IO ()
main = do
  -- install handler for SIGTERM: throw UserInterrupt to main thread
  -- (SIGINT is already installed by default)
  mainId <- myThreadId
  installHandler sigTERM (Catch (throwTo mainId UserInterrupt)) Nothing

  -- spawn threads
  children <- for [0..9] $ \_ ->
    async thread

  -- wait for threads to terminate
  let waitAll = do
        for_ children $ \ t -> do
          wait t
          pure ()
        putStrLn "Normal termination"
  waitAll `catch` \e -> case e of
    UserInterrupt -> do
      putStrLn "Killed."
      putStrLn "Cleaning up..."
      for_ children $ \ t ->
        cancel t
      putStrLn "Waiting on children"
      results <- for children $ \ t ->
        wait t
      putStrLn ("Job results: " ++ show results)
    e -> throwIO e

Output after SIGINT or SIGTERM:

^CKilled.
Cleaning up...
Waiting on children
Job results: [Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted]

Update

Your modified example creates threads dynamically (you don't know ahead of time how many threads you will need) in a loop that is meant to listen for jobs. That complicates the structure of the program a little. Below is a fixed version.

  1. In the parent thread, try listen waits for either a job or an async exception. We pattern-match on the result outside of the exception handler try; we either keep looping without growing the stack (run is tail-recursive), or get an exception and clean up the child threads.
  2. mask_ makes it so that only the listen part may raise an async exception. (Here it's not necessary to restore the mask on listen because it uses a blocking operation, which already unmasks async exceptions.)
  3. Make sure your cleanup actually waits for the children.

import Control.Concurrent (newChan, writeChan, readChan)
import Control.Concurrent.Async (Async, async, withAsync, asyncWithUnmask, wait, cancelWith)
import Control.Exception (AsyncException (..), Exception, SomeException, catch, try, throwIO, mask_)
import System.Posix (Signal)
import System.Posix.Signals (Handler (..), installHandler, sigHUP, sigINT, sigTERM, sigUSR1, sigUSR2, sigXCPU, sigXFSZ)
import Control.Concurrent (myThreadId, threadDelay, throwTo)
import Control.Monad (forM_)
import Data.Data (Typeable)
import Data.Foldable (for_)

data Result = Done | Aborted deriving (Show)

termMsg :: Int -> Result -> IO ()
termMsg n s = putStrLn $ "Thread " ++ show n ++ " terminated with " ++ show s

data Job = Job Int
  deriving Show

thread :: Job -> IO ()
thread (Job n) = job `catch` asyncHandler
  where
    job = do
      for_ ([0 .. 9] :: [Int]) $ \_ -> do
        putStrLn $ "Thread " ++ show n ++ " alive..."
        threadDelay $ 5000000 * n
      termMsg n Done

    asyncHandler :: AsyncException -> IO ()
    asyncHandler _ = do
      -- cleanup
      termMsg n Aborted

parent :: IO Job -> IO ()
parent listen = mask_ $ run []
  where
    run :: [Async ()] -> IO ()
    run w = do
      event <- try listen
      putStrLn $ "Main thread alive (loop " ++ show event ++ ")"
      case event :: Either SignalException Job of
        Right job -> do
          t <- asyncWithUnmask $ \unmask -> unmask (thread job)
          let nw = w ++ [t]
          run nw
        Left e -> do
          cleanupChildren w
          throwIO e

main :: IO ()
main = do
  installSignalHandlers
  -- `parent` waits for jobs by calling its `IO job` argument.
  -- We implement it here by reading from a channel which gets populated by the createJobs thread below.
  chan <- newChan
  let listen = readChan chan
      createJobs = for_ [1 .. 9] $ \i -> do
        threadDelay 1000000
        writeChan chan (Job i)
  withAsync createJobs $ \_ ->
    parent listen

cleanupChildren :: [Async ()] -> IO ()
cleanupChildren children = do
  putStrLn "Cleaning up children..."
  for_ children $ \t -> cancelWith t ThreadKilled
  -- Wait for the children to terminate their own cleanup
  for_ children $ \t -> wait t >> pure ()

data SignalException = SignalException Signal String
  deriving (Show, Typeable, Eq)
instance Exception SignalException

signalsToHandle :: [(Signal, String)]
signalsToHandle = [(sigHUP, "SIGHUP"), (sigINT, "SIGINT"), (sigTERM, "SIGTERM"), (sigUSR1, "SIGUSR1"), (sigUSR2, "SIGUSR2"), (sigXCPU, "SIGXCPU"), (sigXFSZ, "SIGXFSZ")]

installSignalHandlers :: IO ()
installSignalHandlers = do
  mainId <- myThreadId
  forM_ signalsToHandle $ \(sig, name) -> installHandler sig (Catch (throwTo mainId $ SignalException sig name)) Nothing
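
Point 2 above (a blocking operation stays interruptible under mask_) can be checked in isolation. This is a minimal sketch using only base; runDemo and the never-filled MVar are hypothetical stand-ins for the listen call. A helper thread throws UserInterrupt at the main thread while the main thread is blocked inside mask_, and try captures it.

```haskell
import Control.Concurrent (forkIO, myThreadId, threadDelay, throwTo)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar)
import Control.Exception (AsyncException (UserInterrupt), mask_, try)

runDemo :: IO (Either AsyncException ())
runDemo = do
  never <- newEmptyMVar :: IO (MVar ())  -- never filled: takeMVar will block
  mainId <- myThreadId
  _ <- forkIO $ do
    threadDelay 100000                   -- give the main thread time to block
    throwTo mainId UserInterrupt
  -- takeMVar blocks, so it is an interruptible point even though
  -- async exceptions are masked around it.
  mask_ (try (takeMVar never))

main :: IO ()
main = runDemo >>= print
```

Code between interruptible operations (for example, spawning the worker and extending the list) is still protected by the mask, which is what keeps the child list consistent.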

colophonemes

OK, I think that I've found one possible answer, which is basically just to get STM involved:

import Control.Concurrent.Async (Async (asyncThreadId), async, cancelWith)

import Control.Exception (AsyncException (..), Exception, SomeException, catch, throwIO)
import System.Posix (Signal)
import System.Posix.Signals (Handler (..), installHandler, sigHUP, sigINT, sigTERM, sigUSR1, sigUSR2, sigXCPU, sigXFSZ)

import Control.Concurrent (myThreadId, threadDelay, throwTo)
import Control.Concurrent.STM (TVar, atomically, modifyTVar', newTVarIO, readTVarIO)
import Control.Monad (forM_)
import Data.Data (Typeable)

import Data.Foldable (for_)

data Result = Done | Aborted deriving (Show)

type ThreadList = [Async ()]
type GlobalThreadList = TVar ThreadList

main :: IO ()
main = do
  installSignalHandlers
  threads <- newTVarIO []
  parent threads `catch` parentExceptionHandler

parent :: GlobalThreadList -> IO ()
parent threads = do
  run 0 `catch` handler
  where
    run :: Int -> IO ()
    run n = do
      putStrLn $ "Main thread alive (loop " ++ show n ++ ")"
      async (thread n threads) >>= register
      threadDelay 1000000
      run (n + 1)

    -- Append the new worker in one atomic step, so a concurrent
    -- removal by a finishing worker cannot be lost.
    register :: Async () -> IO ()
    register t = atomically $ modifyTVar' threads (++ [t])

    handler :: SomeException -> IO ()
    handler e = do
      print e
      children <- readTVarIO threads
      cleanupChildren children
      throwIO e

cleanupChildren :: [Async ()] -> IO ()
cleanupChildren children = do
  putStrLn "Cleaning up children..."
  for_ children $ \t -> cancelWith t ThreadKilled

termMsg :: Int -> Result -> IO ()
termMsg n s = putStrLn $ "Thread " ++ show n ++ " terminated with " ++ show s

thread :: Int -> GlobalThreadList -> IO ()
thread n threads = job `catch` asyncHandler
  where
    job = do
      for_ ([0 .. 9] :: [Int]) $ \_ -> do
        putStrLn $ "Thread " ++ show n ++ " alive..."
        threadDelay $ 500000 * n
      termMsg n Done
      cleanup

    -- Remove this thread from the shared list in one atomic step.
    cleanup :: IO ()
    cleanup = do
      myTid <- myThreadId
      atomically $ modifyTVar' threads (filter ((/= myTid) . asyncThreadId))

    asyncHandler :: AsyncException -> IO ()
    asyncHandler _ = do
      termMsg n Aborted

parentExceptionHandler :: SomeException -> IO ()
parentExceptionHandler e = do
  putStrLn $ "Terminating with " ++ show e
  throwIO e

data SignalException = SignalException Signal String
  deriving (Show, Typeable, Eq)
instance Exception SignalException

signalsToHandle :: [(Signal, String)]
signalsToHandle = [(sigHUP, "SIGHUP"), (sigINT, "SIGINT"), (sigTERM, "SIGTERM"), (sigUSR1, "SIGUSR1"), (sigUSR2, "SIGUSR2"), (sigXCPU, "SIGXCPU"), (sigXFSZ, "SIGXFSZ")]

installSignalHandlers :: IO ()
installSignalHandlers = do
  mainId <- myThreadId
  forM_ signalsToHandle $ \(sig, name) -> installHandler sig (Catch (throwTo mainId $ SignalException sig name)) Nothing

Result:

Main thread alive (loop 0)
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 alive...
Thread 0 terminated with Done
Main thread alive (loop 1)
Thread 1 alive...
Thread 1 alive...
Main thread alive (loop 2)
Thread 2 alive...
Thread 1 alive...
Thread 1 alive...
Main thread alive (loop 3)
Thread 2 alive...
Thread 3 alive...
Thread 1 alive...
Thread 1 alive...
Main thread alive (loop 4)
Thread 2 alive...
Thread 4 alive...
Thread 1 alive...
Thread 3 alive...
Thread 1 alive...
Thread 2 alive...
Main thread alive (loop 5)
Thread 5 alive...
Thread 1 alive...
^CSignalException 2 "SIGINT"
Cleaning up children...
Thread 1 terminated with Aborted
Thread 2 terminated with Aborted
Thread 3 terminated with Aborted
Thread 4 terminated with Aborted
Thread 5 terminated with Aborted
Terminating with SignalException 2 "SIGINT"