#include <stdio.h>
#include <stdlib.h>
#include <pthread.h>
#include "jthread.h"

/*
 * Shared state; every variable below is protected by jt_lock.
 *
 * Dying threads are reaped through a single hand-off slot.  A thread in
 * jthread_exit() waits until the slot is free (ndying_threads == 0),
 * stores its own id in jt_dying_thread and wakes the collector through
 * jt_cond.  The collector joins that thread, empties the slot and wakes
 * the next dying thread through jt_die_block.
 */
static int ndying_threads = 0;     /* 1 while jt_dying_thread awaits a join, else 0 */
static pthread_t jt_dying_thread;  /* thread published by jthread_exit() */
static pthread_mutex_t jt_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t jt_cond = PTHREAD_COND_INITIALIZER;      /* slot filled */
static pthread_cond_t jt_die_block = PTHREAD_COND_INITIALIZER; /* slot emptied */
static int jt_nthreads = 0;        /* counted threads not yet joined */
static int jt_gc_running = 0;      /* 1 while the collector thread is alive */

/* Report a failed pthread call and terminate: jthread_create() has no way
   to hand an error back to its caller. */
static void jt_fatal(const char *what, int err)
{
  fprintf(stderr, "jthread: %s failed (error %d)\n", what, err);
  exit(1);
}

/*
 * Collector thread: joins every thread that retires through jthread_exit()
 * and returns once all counted threads have been joined.  jt_lock is
 * released during the join so creators are not held up; the slot stays
 * occupied meanwhile, so no other dying thread can overwrite it.
 */
static void *jthread_garbage_collector(void *arg)
{
  pthread_t victim;

  (void) arg;
  pthread_mutex_lock(&jt_lock);
  while (jt_nthreads > 0) {
    if (ndying_threads == 0) {
      pthread_cond_wait(&jt_cond, &jt_lock);
      continue;   /* re-check: wake-ups may be spurious */
    }
    victim = jt_dying_thread;
    pthread_mutex_unlock(&jt_lock);
    /* The victim has already released jt_lock and is inside pthread_exit().
       A failed join (e.g. the thread detached itself) is not fatal: the
       slot is emptied either way. */
    pthread_join(victim, NULL);
    pthread_mutex_lock(&jt_lock);
    ndying_threads = 0;
    jt_nthreads--;
    pthread_cond_broadcast(&jt_die_block);
  }
  /* Threads woken by the broadcast above can only re-acquire jt_lock after
     this store, so they see that nobody is left to join them. */
  jt_gc_running = 0;
  pthread_mutex_unlock(&jt_lock);
  return NULL;
}

/*
 * Terminate the calling thread.  While the collector is running, the
 * thread first publishes itself in the hand-off slot (waiting for the slot
 * to empty if another dying thread occupies it) so the collector can join
 * it.  Does not return.
 */
void jthread_exit(void)
{
  pthread_mutex_lock(&jt_lock);
  while (jt_gc_running && ndying_threads > 0) {
    pthread_cond_wait(&jt_die_block, &jt_lock);
  }
  if (jt_gc_running) {
    jt_dying_thread = pthread_self();
    ndying_threads = 1;
    pthread_cond_signal(&jt_cond);
  }
  pthread_mutex_unlock(&jt_lock);
  pthread_exit(NULL);
}

/* Start package handed from jthread_create() to the new thread. */
struct jt_start {
  void *(*func)(void *);
  void *arg;
};

/* Entry point of every jthread: run the user's function, then retire
   through jthread_exit() so the thread is joined even if func returns. */
static void *jt_start_routine(void *p)
{
  struct jt_start start = *(struct jt_start *) p;

  free(p);
  start.func(start.arg);
  jthread_exit();
  return NULL;   /* not reached: jthread_exit() does not return */
}

/*
 * Run func(arg) in a new thread.  The new thread is joined by the collector
 * once func returns or the thread calls jthread_exit().
 *
 * The thread that starts the collector (normally main, on its first call)
 * is counted too and is expected to finish with jthread_exit() (or end the
 * process).  The collector returns once every counted thread is joined, so
 * the process can terminate when the last of them is gone.  If the
 * collector has exited, the next call starts a new one.
 *
 * Allocation and pthread_create failures are fatal: they are reported on
 * stderr and the process exits with status 1.
 */
void jthread_create(void * (*func)(void *), void *arg)
{
  struct jt_start *start;
  pthread_t tid;
  int err;

  start = malloc(sizeof *start);
  if (start == NULL) {
    fprintf(stderr, "jthread_create: out of memory\n");
    exit(1);
  }
  start->func = func;
  start->arg = arg;

  pthread_mutex_lock(&jt_lock);
  if (!jt_gc_running) {
    jt_nthreads++;   /* the calling thread */
    err = pthread_create(&tid, NULL, jthread_garbage_collector, NULL);
    if (err != 0) jt_fatal("pthread_create (collector)", err);
    pthread_detach(tid);   /* nobody joins the collector itself */
    jt_gc_running = 1;
  }
  /* Count the new thread before it exists, so the collector cannot reach
     zero while the thread is still starting up. */
  jt_nthreads++;
  err = pthread_create(&tid, NULL, jt_start_routine, start);
  if (err != 0) jt_fatal("pthread_create", err);
  pthread_mutex_unlock(&jt_lock);
}

