Friday, January 6, 2012

Groovy/Grails : Integration tests for multi-threaded application

Recently I came across a situation in my project where I had to write an integration test for multi-threaded code. I had a business method in a service class that executed part of its work in a separate thread. An integration test for such a method should wait for the thread(s) created by the service to complete before asserting the result, but I could not find any way to tell whether those threads had completed. So I decided to refactor my code so that it could be determined from outside whether the threads had finished.


Accordingly, I created the ThreadCompletionTracker class shown below.

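package com.beachbody.pioneer.util

import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import org.apache.commons.logging.LogFactory

class ThreadCompletionTracker {
    private static def log = LogFactory.getLog(ThreadCompletionTracker)

    // One latch per waiting (test) thread, keyed by that thread's latch key
    private static ConcurrentHashMap<String, CountDownLatch> latchMap = new ConcurrentHashMap<String, CountDownLatch>()

    // Identifies the waiting thread; withTracking copies it into worker threads
    private static ThreadLocal<String> latchKey = new LatchKey()

    // Runs the closure, then blocks until threadCount tracked tasks have finished
    // or 30 seconds have passed. Returns false if the wait timed out.
    static boolean waitToComplete(Integer threadCount=1, Closure closure){
        log.info("Inside wait to complete")
        String key = latchKey.get()
        CountDownLatch countDownLatch = new CountDownLatch(threadCount)
        try{
            latchMap.put(key, countDownLatch)
            closure.call()
            boolean completed = countDownLatch.await(30, TimeUnit.SECONDS)
            if(!completed){
                log.warn("Timed out waiting for tracked threads, key : " + key)
            }
            return completed
        }finally {
            log.info("Removing key from latch map : " + key)
            latchMap.remove(key)
        }
    }

    static void withTracking(ExecutorService executor, Closure closure){
        // Delegate to a non-overloaded helper: a Closure is itself a Callable, so
        // calling withTracking(executor, closure as Callable) from here would
        // dispatch back to this same overload and recurse forever
        submitTracked(executor, closure as Callable)
    }

    static void withTracking(ExecutorService executor, Callable callable){
        submitTracked(executor, callable)
    }

    private static void submitTracked(ExecutorService executor, Callable callable){
        // Read the key here, on the caller's thread, before handing off the work
        String key = latchKey.get()
        Closure task = {
            runCallable(key, callable)
        }
        executor.submit(task)
    }

    private static void runCallable(String key, Callable callable){
        log.info("Running callable inside runCallable")
        try{
            latchKey.set(key)
            callable.call()
        }finally {
            // A single get() instead of containsKey() + get(): the waiter may have
            // timed out and removed the latch in between
            CountDownLatch countDownLatch = latchMap.get(key)
            if(countDownLatch != null){
                log.info("latchMap contains key : ${key}")
                countDownLatch.countDown()
            }
            // Pool threads are reused, so do not leave the key behind
            latchKey.remove()
        }
    }
}

class LatchKey extends ThreadLocal<String> {
    @Override
    protected String initialValue(){
        // Thread name plus creation time, computed once per thread
        return Thread.currentThread().name + System.currentTimeMillis()
    }
}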


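This class has two important public methods:
1. waitToComplete - runs the given closure, then waits until threadCount (default 1) threads started from it through withTracking have completed, or until 30 seconds have passed.
2. withTracking - submits a task to an executor with tracking enabled, so that the waiting thread is notified when the task finishes.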
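Both methods revolve around a java.util.concurrent.CountDownLatch: waitToComplete creates one with a count of threadCount and blocks on it, and every task submitted through withTracking counts it down when it finishes. Below is a minimal standalone Groovy sketch of just that mechanism, with the ThreadLocal key lookup left out. The pool size, the number of tasks and the processed counter are illustrative assumptions, not part of the tracker.

```groovy
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

// Stands in for the work a service hands to its thread pool (hypothetical)
AtomicInteger processed = new AtomicInteger(0)
ExecutorService pool = Executors.newFixedThreadPool(2)

// Like waitToComplete(2) { ... }: one count per background task we expect
CountDownLatch latch = new CountDownLatch(2)

2.times {
    // Like withTracking/runCallable: run the task, then count down in finally
    pool.execute {
        try {
            Thread.sleep(100)              // simulate slow work
            processed.incrementAndGet()
        } finally {
            latch.countDown()              // release the waiter even if the task fails
        }
    }
}

// The "test" thread blocks here until both tasks are done, or 30 seconds pass
boolean finished = latch.await(30, TimeUnit.SECONDS)
pool.shutdown()

println "finished: ${finished}, processed: ${processed.get()}"
```

Counting down inside finally is the important design choice: if a task threw before reaching countDown(), the test thread would sit out the full timeout.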

Here is the test case written using this class.


    void testSendEmail(){
        assertEquals(0, greenMail.getReceivedMessages().length)
        // Blocks until the thread started through withTracking has finished (or 30 seconds pass)
        ThreadCompletionTracker.waitToComplete {
            EmailRequest request = new EmailRequest()
            //Write some code to populate EmailRequest object correctly
            mailSender.sendEmail(request)
        }
        assertEquals(1, greenMail.getReceivedMessages().length)
    }

The integration test above sends a mail using the mailSender service. Inside the service, the mail is sent from a separate thread, as the MailSender class below shows. Because the call to mailSender.sendEmail is wrapped in waitToComplete, the test's thread waits until the mail-sending thread has completed before running the final assertion.
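This wait only works because the mail-sending thread counts down the latch that waitToComplete registered on the test thread. A ThreadLocal value set on one thread is not visible from a pool thread, so withTracking reads the caller's latch key before submitting the task and runCallable sets it again inside the worker. The standalone sketch below shows that hand-off; the key value is made up, and the remove() cleanup is an extra precaution for reused pool threads.

```groovy
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference

// Stand-in for ThreadCompletionTracker.latchKey; the key value is hypothetical
ThreadLocal<String> latchKey = new ThreadLocal<String>()
latchKey.set('test-thread-key')            // set on the "test" thread

ExecutorService pool = Executors.newSingleThreadExecutor()
AtomicReference<String> seenWithoutCopy = new AtomicReference<String>()
AtomicReference<String> seenWithCopy = new AtomicReference<String>()

// 1. Naive hand-off: the pool thread has its own, empty ThreadLocal slot
pool.execute {
    seenWithoutCopy.set(latchKey.get())
}

// 2. What withTracking/runCallable do: read the key on the caller's thread,
//    then set it again inside the worker before running the task
String captured = latchKey.get()
pool.execute {
    latchKey.set(captured)
    try {
        seenWithCopy.set(latchKey.get())
    } finally {
        latchKey.remove()                  // pool threads are reused, so clean up
    }
}

pool.shutdown()
pool.awaitTermination(5, TimeUnit.SECONDS)

// prints "without copy: null, with copy: test-thread-key"
println "without copy: ${seenWithoutCopy.get()}, with copy: ${seenWithCopy.get()}"
```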

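import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import com.beachbody.pioneer.util.ThreadCompletionTracker

class MailSender {

    static transactional = true
    def handlerMap
    private final ExecutorService pool = Executors.newFixedThreadPool(10)

    def sendEmail(EmailRequest emailRequest){
        emailRequest.save(flush:true)
        // Wrap the handler call in a Callable so it runs later on a pool thread.
        // (sendRequestToHandler(emailRequest) as Callable would run the handler
        // right here on the caller's thread and then try to cast its Boolean
        // result to Callable.) Map coercion yields a plain Callable rather than a
        // Closure, so the withTracking(ExecutorService, Callable) overload is used.
        Callable task = [call: { sendRequestToHandler(emailRequest) }] as Callable
        ThreadCompletionTracker.withTracking(pool, task)
    }

    def sendRequestToHandler(EmailRequest emailRequest){
        try{
            handlerMap[emailRequest.type].handle(emailRequest)
        }catch(PioneerException pe){
            log.error(pe.message, pe)
        }
        return true
    }
}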

Above is the MailSender service, which hands each email to a handler for processing asynchronously. It submits that work through ThreadCompletionTracker's withTracking method, which lets us track whether all the threads created by the MailSender service have completed. Only after they have completed do we run the assert statements in the integration test. waitToComplete also limits the wait to 30 seconds, because we cannot wait indefinitely for threads to complete.
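The 30-second limit relies on the fact that CountDownLatch.await(long, TimeUnit) does not throw when time runs out; it just returns false, and it returns true as soon as the count reaches zero. A short standalone sketch of both cases, using a 200 ms timeout and a latch that nobody counts down as a hypothetical stand-in for a hung worker thread:

```groovy
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

// Hypothetical hung worker: nothing ever calls countDown() on this latch
CountDownLatch neverReleased = new CountDownLatch(1)
long start = System.nanoTime()
boolean completed = neverReleased.await(200, TimeUnit.MILLISECONDS)   // gives up, returns false
long waitedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)

// Worker already finished: the count is zero, so await returns true immediately
CountDownLatch released = new CountDownLatch(1)
released.countDown()
boolean alreadyDone = released.await(200, TimeUnit.MILLISECONDS)

println "completed: ${completed} (waited about ${waitedMs} ms), alreadyDone: ${alreadyDone}"
```

Checking that boolean in a test, for example with assertTrue, turns a hung worker into an explicit failure instead of a confusing assertion against half-finished work.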

