diff options
Diffstat (limited to 'fastnn/io/Example.cpp')
-rw-r--r-- | fastnn/io/Example.cpp | 186 |
1 files changed, 186 insertions, 0 deletions
diff --git a/fastnn/io/Example.cpp b/fastnn/io/Example.cpp new file mode 100644 index 0000000..8f200b7 --- /dev/null +++ b/fastnn/io/Example.cpp @@ -0,0 +1,186 @@ + +#include <deque> +#include <vector> +#include <string> + + +extern "C" { + +#include "../threads/lib/THThread.h" +#include "Example.h" +#include "stdlib.h" +#include "stdio.h" + +#include "../../nerv/lib/matrix/generic/elem_type.h" +#include "common.h" + +extern Matrix* nerv_matrix_cuda_float_create(long nrow, long ncol, Status *status); +void nerv_matrix_cuda_float_copy_fromd(Matrix *a, const Matrix *b, + int a_begin, int b_begin, int b_end, + Status *status); + +struct Example +{ + std::vector<Matrix *> inputs; + std::string id; + int refcount; +}; + +struct ExamplesRepository +{ + int buffer_size_; + THSemaphore *full_semaphore_; + THSemaphore *empty_semaphore_; + THMutex *examples_mutex_; + + std::deque<Example*> examples_; + bool done_; + int refcount; + int gpuid; +}; + +Example* Example_new() +{ + Example *example = new Example; //(Example*)malloc(sizeof(Example)); + example->refcount = 1; + return example; +} + +Example* Example_newWithId(long id) +{ + Example* example = (Example*)(id); + __sync_fetch_and_add(&example->refcount, 1); + return example; +} + +long Example_id(Example *example) +{ + return (long)(example); +} + +void Example_destroy(Example* example) +{ + //printf("Example_destroy: %d\n", example->inputs.size()); + if (NULL != example && __sync_fetch_and_add(&example->refcount, -1) == 1) + { + delete example; + example = NULL; + } +} + +int Example_size(Example* example) +{ + return example->inputs.size(); +} + +Matrix* Example_at(Example* example, int idx) +{ + return example->inputs.at(idx); +} + +void Example_pushback(Example* example, Matrix* m) +{ + Status status; + Matrix *newm = nerv_matrix_cuda_float_create(m->nrow, m->ncol, &status); + nerv_matrix_cuda_float_copy_fromd(newm, m, 0, 0, m->nrow, &status); + //__sync_fetch_and_add(&m->refcount, 1); + return example->inputs.push_back(newm); +} + +////////////////////////////////////////////////////////////// + + +ExamplesRepository* ExamplesRepository_new(int buffersize) +{ + ExamplesRepository *repo = new ExamplesRepository; //(ExamplesRepository*)malloc(sizeof(ExamplesRepository)); + repo->buffer_size_ = buffersize; + repo->full_semaphore_ = THSemaphore_new(0); + repo->empty_semaphore_ = THSemaphore_new(buffersize); + repo->examples_mutex_ = THMutex_new(); +// repo->examples_ = new std::deque<Example*>(); + repo->done_ = false; + repo->refcount = 1; + repo->gpuid = -1; + return repo; +} + +ExamplesRepository* ExamplesRepository_newWithId(long id) +{ + ExamplesRepository *repo = (ExamplesRepository*)(id); + __sync_fetch_and_add(&repo->refcount, 1); + return repo; +} + +long ExamplesRepository_id(ExamplesRepository* repo) +{ + return (long)(repo); +} + +int ExamplesRepository_getGpuId(ExamplesRepository* repo) +{ + return repo->gpuid; +} + +void ExamplesRepository_setGpuId(ExamplesRepository* repo, int gpuid) +{ + repo->gpuid = gpuid; +} + +void ExamplesRepository_destroy(ExamplesRepository* repo) +{ + if (NULL != repo && __sync_fetch_and_add(&repo->refcount, -1) == 1) + { + if (repo->full_semaphore_) + THSemaphore_free(repo->full_semaphore_); + if (repo->empty_semaphore_) + THSemaphore_free(repo->empty_semaphore_); + if (repo->examples_mutex_) + THMutex_free(repo->examples_mutex_); + delete repo; + repo = NULL; + } +} + +void AcceptExample(ExamplesRepository *repo, Example *example) +{ + THSemaphore_wait(repo->empty_semaphore_); + THMutex_lock(repo->examples_mutex_); + __sync_fetch_and_add(&example->refcount, 1); + repo->examples_.push_back(example); + THMutex_unlock(repo->examples_mutex_); + THSemaphore_signal(repo->full_semaphore_); +} + +void ExamplesDone(ExamplesRepository *repo) +{ + for (int i = 0; i < repo->buffer_size_; i++) + THSemaphore_wait(repo->empty_semaphore_); + + repo->done_ = true; + THSemaphore_signal(repo->full_semaphore_); +} + +Example* ProvideExample(ExamplesRepository *repo) +{ + Example *ans = NULL; + THSemaphore_wait(repo->full_semaphore_); + if (repo->done_) + { + THSemaphore_signal(repo->full_semaphore_); // Increment the semaphore so + // the call by the next thread will not block. + return NULL; // no examples to return-- all finished. + } + else + { + THMutex_lock(repo->examples_mutex_); + ans = repo->examples_.front(); + repo->examples_.pop_front(); + THMutex_unlock(repo->examples_mutex_); + THSemaphore_signal(repo->empty_semaphore_); + } + return ans; +} + +} + + |