9 #include <condition_variable>
31 using Batch = std::vector<Pair>;
46 size_t num_workers = 4,
size_t prefetch_batches = 2);
77 : loader_(other.loader_), batch_idx_(other.batch_idx_),
78 current_batch_(std::nullopt) {}
87 loader_ = other.loader_;
88 batch_idx_ = other.batch_idx_;
89 current_batch_ = std::nullopt;
164 mutable std::optional<Batch> current_batch_;
236 void worker_thread();
246 Batch load_batch(
size_t batch_idx);
259 std::optional<Batch> get_batch(
size_t batch_idx);
297 size_t prefetch_batches_;
303 size_t dataset_size_;
317 std::vector<size_t> indices_;
325 std::vector<std::thread> workers_;
334 std::queue<std::pair<size_t, Batch>> batch_queue_;
342 std::mutex queue_mutex_;
349 std::condition_variable queue_cv_;
357 std::condition_variable ready_cv_;
366 std::atomic<bool> stop_workers_;
375 std::atomic<bool> epoch_finished_;
383 size_t current_batch_idx_;
Iterator(const Iterator &other)
Iterator & operator=(Iterator &&)=default
Iterator(Iterator &&)=default
Iterator(Dataloader *loader, size_t batch_idx)
bool operator!=(const Iterator &other) const
bool operator==(const Iterator &other) const
std::ptrdiff_t difference_type
std::input_iterator_tag iterator_category
Iterator & operator=(const Iterator &other)
Iterator operator++(int)=delete
BatchView operator*() const
size_t get_prefetch_batches() const
void set_prefetch_batches(size_t prefetch_batches)
Dataloader(IDataset &dataset, size_t batch_size, bool shuffle=false, size_t num_workers=4, size_t prefetch_batches=2)
void set_num_workers(size_t num_workers)
std::span< Pair > BatchView
size_t get_num_workers() const
std::vector< Pair > Batch
std::pair< std::unique_ptr< IData >, size_t > Pair