#include <errno.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include "dllist.h"
#include "jthread.h"

/* Queue of ids of threads that have called jthread_exit() and still need to
   be joined by the collector thread.  It is NULL until jthread_system_init()
   runs, and the public entry points use that NULL as their "not initialized"
   test. */
static Dllist jt_dying_threads = NULL;
/* Protects the contents of jt_dying_threads and the jt_nthreads counter. */
static pthread_mutex_t jt_lock;
/* Signalled by jthread_exit(); the collector thread waits on it. */
static pthread_cond_t jt_cond;
/* Thread count: set to 1 by jthread_system_init() (presumably to account for
   the thread that calls it -- confirm), incremented by jthread_create(), and
   decremented each time the collector joins a thread.  The collector
   terminates when it reaches 0. */
static int jt_nthreads = 0;

/*
 * Collector thread body.  Joins every thread queued on jt_dying_threads,
 * decrementing jt_nthreads for each one, then sleeps on jt_cond until
 * jthread_exit() queues another.  Terminates once jt_nthreads drops to 0.
 *
 * arg is unused.  Always returns NULL.
 */
static void *jthread_garbage_collector(void *arg)
{
  (void) arg;

  pthread_mutex_lock(&jt_lock);
  for (;;) {
    /* Drain the queue.  jt_lock stays held across the join; the thread being
       joined released it in jthread_exit() before calling pthread_exit(), so
       it does not need the lock to finish. */
    while (!dll_empty(jt_dying_threads)) {
      Dllist d = dll_first(jt_dying_threads);
      /* NOTE(review): the id is stored in the unsigned int field of the list
         value, so it is narrowed wherever pthread_t is wider than
         unsigned int -- confirm on the target platform. */
      unsigned int tid = d->val.ui;
      int rc = pthread_join(tid, NULL);

      if (rc != 0) {
        /* Report the failure but still retire the entry, so the count can
           reach zero and the collector can terminate. */
        fprintf(stderr, "jthread: pthread_join failed (error %d)\n", rc);
      }
      dll_delete_node(d);
      jt_nthreads--;
    }

    /* Nothing left to wait for: leave the loop so the lock is released
       before this thread ends, instead of exiting with it held. */
    if (jt_nthreads == 0) break;

    pthread_cond_wait(&jt_cond, &jt_lock);
  }
  pthread_mutex_unlock(&jt_lock);
  return NULL;
}

/* Bundles a user thread function with its argument so that both can travel
   through pthread_create()'s single void * parameter.  Allocated in
   jthread_create() and freed by jthread_starter(). */
typedef struct {
  void (*func)(void *);   /* user entry point */
  void *arg;              /* argument passed to func */
} Closure;

void *jthread_starter(void *arg)
{
  Closure *c;
  void (*f)(void *);
  void *a;

  c = (Closure *) arg;
  f = c->func;
  a = c->arg;
  free(c);
  (*f)(a);
  jthread_exit();
  return NULL;
}

/*
 * One-time initialisation of the jthread layer: creates the dying-thread
 * queue, the lock and condition variable, sets the thread count to 1 and
 * starts the collector thread.
 *
 * Prints a message to stderr and exits with status 1 if called a second
 * time, or if the collector thread cannot be started.
 */
void jthread_system_init()
{
  pthread_t tid;
  int rc;

  if (jt_dying_threads != NULL) {
    fprintf(stderr, "Jthread_system_init called twice\n");
    exit(1);
  }

  jt_dying_threads = new_dllist();
  pthread_mutex_init(&jt_lock, NULL);

  /* The collector's first action is to take jt_lock, so holding it here
     keeps the collector from running until setup is complete. */
  pthread_mutex_lock(&jt_lock);
  pthread_cond_init(&jt_cond, NULL);
  jt_nthreads = 1;
  rc = pthread_create(&tid, NULL, jthread_garbage_collector, NULL);
  pthread_mutex_unlock(&jt_lock);

  /* Without a collector nothing ever joins exiting threads, so treat a
     failed start as fatal, like the other misuse checks in this file. */
  if (rc != 0) {
    fprintf(stderr, "Jthread_system_init: cannot start collector thread (error %d)\n", rc);
    exit(1);
  }
}

/*
 * Starts a new thread that runs func(arg) and then finishes through
 * jthread_exit().
 *
 * Returns 0 on success, ENOMEM if the start-up record cannot be allocated,
 * or the error number returned by pthread_create().  On failure no thread
 * is started and the thread count is left unchanged.
 *
 * Prints a message to stderr and exits with status 1 if
 * jthread_system_init() has not been called.
 */
int jthread_create(void (*func)(void *), void *arg)
{
  pthread_t tid;
  Closure *c;
  int rc;

  if (jt_dying_threads == NULL) {
    fprintf(stderr, "Jthread_system_init never called\n");
    exit(1);
  }

  c = malloc(sizeof *c);
  if (c == NULL) return ENOMEM;
  c->func = func;
  c->arg = arg;

  /* Create the thread and count it in one critical section.  The count only
     grows when a thread really exists, and the collector, which needs
     jt_lock to look at the count, never observes a half-done update. */
  pthread_mutex_lock(&jt_lock);
  rc = pthread_create(&tid, NULL, jthread_starter, c);
  if (rc == 0) jt_nthreads++;
  pthread_mutex_unlock(&jt_lock);

  /* On success the new thread owns c and frees it in jthread_starter(). */
  if (rc != 0) free(c);
  return rc;
}

/*
 * Terminates the calling thread: queues its id on jt_dying_threads, wakes
 * the collector thread, and calls pthread_exit().  Does not return.
 *
 * Prints a message to stderr and exits with status 1 if
 * jthread_system_init() has not been called.
 */
void jthread_exit()
{
  pthread_t self;

  if (!jt_dying_threads) {
    fprintf(stderr, "Jthread_system_init never called\n");
    exit(1);
  }

  self = pthread_self();

  pthread_mutex_lock(&jt_lock);
  /* NOTE(review): the id is stored through new_jval_ui(), i.e. as an
     unsigned int; it is narrowed wherever pthread_t is wider than that --
     confirm on the target platform. */
  dll_append(jt_dying_threads, new_jval_ui(self));
  pthread_cond_signal(&jt_cond);
  pthread_mutex_unlock(&jt_lock);

  pthread_exit(NULL);
}

