ReUseX  0.0.1
3D Point Cloud Processing for Building Reuse
Loading...
Searching...
No Matches
Dataloader.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2026 Povl Filip Sonne-Frederiksen
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4#pragma once
7
8#include <atomic>
9#include <condition_variable>
10#include <memory>
11#include <mutex>
12#include <optional>
13#include <queue>
14#include <span>
15#include <thread>
16#include <vector>
17
18namespace ReUseX::vision {
19
20/* * Dataloader provides an iterable interface to a dataset.
21 * It loads batches of data from the dataset in background worker threads and
22 * hands them to the user when requested. It supports shuffling, a configurable
23 * number of worker threads, and prefetching batches. The user can iterate over
24 * the dataloader using a range-based for loop or by manually creating an
25 * iterator. The dataloader will automatically stop the worker threads when the
26 * iteration is complete or when the dataloader is destroyed.
27 * */
29 public:
31 using Batch = std::vector<Pair>;
32 using BatchView = std::span<Pair>;
33
34 /* * Constructs a Dataloader for the given dataset with the specified batch
35 * size, shuffle option, number of worker threads, and number of prefetch
36 * batches.
37 * @param dataset The dataset to load data from.
38 * @param batch_size The number of samples in each batch.
39 * @param shuffle Whether to shuffle the dataset at the beginning of each
40 * epoch.
41 * @param num_workers The number of worker threads to use for loading batches.
42 * @param prefetch_batches The number of batches to prefetch in the
43 * background.
44 * */
45 Dataloader(IDataset &dataset, size_t batch_size, bool shuffle = false,
46 size_t num_workers = 4, size_t prefetch_batches = 2);
47
49
50 /* * Iterator is a class that provides an input iterator interface to the
51 * Dataloader. It allows the user to iterate over the batches of data in the
52 * dataloader using a range-based for loop or by manually creating an
53 * iterator. The iterator will automatically load batches from the dataloader
54 * as needed and will stop when the iteration is complete.
55 * */
56 class Iterator {
57 public:
58 using iterator_category = std::input_iterator_tag;
60 using difference_type = std::ptrdiff_t;
61 using pointer = Batch *;
62 using reference = Batch &;
63
64 /* * Constructs an iterator for the given dataloader and batch index.
65 * @param loader The dataloader to iterate over.
66 * @param batch_idx The index of the batch to start iterating from.
67 * */
68 Iterator(Dataloader *loader, size_t batch_idx);
69
70 /* * Copy constructor for the iterator. Copies the dataloader pointer and the
71 * batch index only; the cached batch starts out empty (std::nullopt) and is
72 * expected to be reloaded when the iterator is dereferenced.
73 * @param other The iterator to copy from.
74 * NOTE(review): Batch is a vector of Pair; if Pair owns a std::unique_ptr it is not copyable anyway — confirm.
75 * */
76 Iterator(const Iterator &other)
77 : loader_(other.loader_), batch_idx_(other.batch_idx_),
78 current_batch_(std::nullopt) {}
79
80 /* * Copy assignment operator for the iterator. Copies the dataloader pointer
81 * and batch index from other and resets the cached batch to std::nullopt; self-assignment changes nothing.
82 * @param other The iterator to copy from.
83 * @return A reference to this iterator.
84 * */
85 Iterator &operator=(const Iterator &other) {
86 if (this != &other) { // skip all work on self-assignment
87 loader_ = other.loader_;
88 batch_idx_ = other.batch_idx_;
89 current_batch_ = std::nullopt; // drop any cached batch; it is never copied
90 }
91 return *this;
92 }
93
94 /* * Move constructor for the iterator. It is defaulted, so the members are
95 * moved member-wise: the dataloader pointer and batch index are copied, and
96 * a cached batch, if present, is moved into the new iterator.
97 * @param other The iterator to move from.
98 * After the move, other keeps its pointer and index; its cached batch, if any, is in a moved-from state.
99 * */
100 Iterator(Iterator &&) = default;
101
102 /* * Move assignment operator for the iterator. It is defaulted (member-wise
103 * move), so a cached batch held by other, if present, is moved into this iterator.
104 * @param other The iterator to move from.
105 * @return A reference to this iterator.
106 * */
108
109 /* * Dereference operator for the iterator. It returns a view of the current
110 * batch of data. If the current batch is not loaded, it will be loaded
111 * from the dataloader.
112 * @return A view of the current batch of data.
113 * */
115
116 /* * Pre-increment operator for the iterator. It advances the iterator to
117 * the next batch of data. If the next batch is not loaded, it will be
118 * loaded from the dataloader.
119 * @return A reference to the advanced iterator.
120 * */
122
123 /* * Pre-increment operator for the iterator. It advances the iterator to
124 * the next batch of data. If the next batch is not loaded, it will be
125 * loaded from the dataloader.
126 * @return A reference to the advanced iterator.
127 * */
129
130 /* * Post-increment operator for the iterator. It is deleted to prevent
131 * inefficient copying of batches. Use the pre-increment operator instead.
132 * */
133 Iterator operator++(int) = delete;
134
135 /* * Equality operator for the iterator. It checks if two iterators are
136 * equal by comparing their dataloader pointers and batch indices.
137 * @param other The iterator to compare with.
138 * @return True if the iterators are equal, false otherwise.
139 * */
140 bool operator==(const Iterator &other) const;
141
142 /* * Inequality operator for the iterator. It checks if two iterators are
143 * not equal by comparing their dataloader pointers and batch indices.
144 * @param other The iterator to compare with.
145 * @return True if the iterators are not equal, false otherwise.
146 * */
147 bool operator!=(const Iterator &other) const;
148
149 private:
150 /* * The dataloader that this iterator belongs to. It is used to load
151 * batches of data when the iterator is dereferenced or advanced.
152 * */
153 Dataloader *loader_;
154 /* * The index of the current batch that this iterator points to. It is used
155 * to determine which batch to load from the dataloader when the iterator is
156 * dereferenced or advanced.
157 * */
158 size_t batch_idx_;
159 /* * The current batch of data that this iterator points to. It is stored as
160 * an optional value because it may not be loaded yet. When the iterator is
161 * dereferenced or advanced, the current batch will be loaded from the
162 * dataloader if it is not already loaded.
163 * */
164 mutable std::optional<Batch> current_batch_;
165 };
166
167 /* * Returns an iterator to the beginning of the dataloader. The iterator will
168 * point to the first batch of data in the dataloader.
169 * @return An iterator to the beginning of the dataloader.
170 * */
172
173 /* * Returns an iterator to the end of the dataloader. The iterator will point
174 * to one past the last batch of data in the dataloader.
175 * @return An iterator to the end of the dataloader.
176 * */
178
179 /* * Returns the total number of batches in the dataloader. This is calculated
180 * based on the size of the dataset and the batch size.
181 * @return The total number of batches in the dataloader.
182 * */
183 size_t size() const;
184
185 /* * Sets the number of worker threads to use for loading batches. This will
186 * affect the performance of the dataloader, as more worker threads can load
187 * batches in parallel, but may also increase the overhead of thread
188 * management. The default number of worker threads is 4.
189 * @param num_workers The number of worker threads to use for loading batches.
190 * */
191 void set_num_workers(size_t num_workers);
192
193 /* * Sets the number of batches to prefetch in the background. This will
194 * affect the performance of the dataloader, as more prefetch batches can
195 * reduce the waiting time for batches to be loaded, but may also increase the
196 * memory usage of the dataloader. The default number of prefetch batches is
197 * 2.
198 * @param prefetch_batches The number of batches to prefetch in the
199 * background.
200 * */
201 void set_prefetch_batches(size_t prefetch_batches);
202
203 /* * Returns the number of worker threads currently used for loading batches.
204 * @return The number of worker threads currently used for loading batches.
205 * */
206 size_t get_num_workers() const;
207
208 /* * Returns the number of batches currently prefetched in the background.
209 * @return The number of batches currently prefetched in the background.
210 * */
211 size_t get_prefetch_batches() const;
212
213 private:
214 /* * Starts the worker threads for loading batches. This function is called at
215 * the beginning of each epoch to initialize the worker threads and start
216 * loading batches from the dataset. The worker threads will continue to load
217 * batches in the background until the epoch is finished or the dataloader is
218 * destroyed.
219 * */
220 void start_epoch();
221
222 /* * Stops the worker threads for loading batches. This function is called at
223 * the end of each epoch or when the dataloader is destroyed to signal the
224 * worker threads to stop loading batches and exit. The worker threads will
225 * check for this signal and exit gracefully when it is set.
226 * */
227 void stop();
228
229 /* * The function that each worker thread runs to load batches of data from
230 * the dataset. Each worker thread will continuously load batches of data from
231 * the dataset and add them to the batch queue until the epoch is finished or
232 * the dataloader is destroyed. The worker threads will synchronize access to
233 * the batch queue using mutexes and condition variables to ensure thread
234 * safety.
235 * */
236 void worker_thread();
237
238 /* * Loads a batch of data from the dataset for the given batch index. This
239 * function is called by the worker threads to load batches of data from the
240 * dataset. It will calculate the indices of the samples in the batch based on
241 * the batch index and batch size, and then load the corresponding samples
242 * from the dataset. The loaded batch will be returned as a vector of pairs.
243 * @param batch_idx The index of the batch to load.
244 * @return A vector of pairs representing the loaded batch of data.
245 * */
246 Batch load_batch(size_t batch_idx);
247
248 /* * Retrieves a batch of data from the batch queue for the given batch index.
249 * This function is called by the iterator to retrieve batches of data from
250 * the dataloader. It will check if the requested batch is already loaded in
251 * the batch queue, and if so, it will return it. If the requested batch is
252 * not loaded yet, it will wait for the worker threads to load it and add it
253 * to the batch queue. The function will return an optional value, which will
254 * be empty if the epoch is finished or if the dataloader is destroyed.
255 * @param batch_idx The index of the batch to retrieve.
256 * @return An optional value containing the retrieved batch of data, or empty
257 * if the epoch is finished or if the dataloader is destroyed.
258 * */
259 std::optional<Batch> get_batch(size_t batch_idx);
260
261 /* * The dataset that this dataloader loads data from. It is a reference to an
262 * IDataset object, which provides the interface for accessing the samples in
263 * the dataset. The dataloader will use this dataset to load batches of data
264 * in the worker threads.
265 * */
266 IDataset &dataset_;
267
268 /* * The batch size that this dataloader uses to load batches of data. It
269 * determines how many samples are included in each batch that the dataloader
270 * loads from the dataset. The batch size is set at the construction of the
271 * dataloader and cannot be changed afterwards.
272 * */
273 size_t batch_size_;
274
275 /* * Whether to shuffle the dataset at the beginning of each epoch. If true,
276 * the dataloader will shuffle the indices of the samples in the dataset at
277 * the beginning of each epoch, which will result in different batches being
278 * loaded in each epoch. If false, the dataloader will load batches in a fixed
279 * order based on the original order of the samples in the dataset.
280 * */
281 bool shuffle_;
282
283 /* * The number of worker threads to use for loading batches. This determines
284 * how many threads will be running in the background to load batches of data
285 * from the dataset. More worker threads can load batches in parallel, but may
286 * also increase the overhead of thread management. The default number of
287 * worker threads is 4.
288 * */
289 size_t num_workers_;
290
291 /* * The number of batches to prefetch in the background. This determines how
292 * many batches will be loaded in advance by the worker threads while the user
293 * is consuming the batches. More prefetch batches can reduce the waiting time
294 * for batches to be loaded, but may also increase the memory usage of the
295 * dataloader. The default number of prefetch batches is 2.
296 * */
297 size_t prefetch_batches_;
298
299 /* * The total number of samples in the dataset. This is obtained from the
300 * dataset object and is used to calculate the total number of batches in the
301 * dataloader.
302 * */
303 size_t dataset_size_;
304
305 /* * The total number of batches in the dataloader. This is calculated based
306 * on the size of the dataset and the batch size. It is used to determine how
307 * many batches are available for iteration in the dataloader.
308 * */
309 size_t num_batches_;
310
311 /* * The indices of the samples in the dataset. This is a vector of size equal
312 * to the number of samples in the dataset, containing the indices of the
313 * samples in the original order. If shuffle is true, this vector will be
314 * shuffled at the beginning of each epoch to provide different batches in
315 * each epoch.
316 * */
317 std::vector<size_t> indices_;
318
319 /* * The worker threads that are responsible for loading batches of data from
320 * the dataset. This is a vector of threads that are created at the beginning
321 * of each epoch and run the worker_thread function to load batches in the
322 * background. The worker threads will continue to run until the epoch is
323 * finished or the dataloader is destroyed.
324 * */
325 std::vector<std::thread> workers_;
326
327 /* * The batch queue that holds the batches of data that have been loaded by
328 * the worker threads. This is a queue of pairs, where each pair consists of a
329 * batch index and the corresponding batch of data. The worker threads will
330 * add batches to this queue as they load them, and the iterator will retrieve
331 * batches from this queue when requested. The batch queue is synchronized
332 * using mutexes and condition variables to ensure thread safety.
333 * */
334 std::queue<std::pair<size_t, Batch>> batch_queue_;
335
336 /* * Mutex and condition variables for synchronizing access to the batch queue
337 * and signaling the worker threads. The queue_mutex is used to protect access
338 * to the batch_queue, while the queue_cv is used to signal the worker threads
339 * when a new batch is added to the queue, and the ready_cv is used to signal
340 * the iterator when a batch is ready to be retrieved from the queue.
341 * */
342 std::mutex queue_mutex_;
343
344 /* * The condition variable used to signal the worker threads when a new batch
345 * is added to the batch queue. The worker threads will wait on this condition
346 * variable when they are idle and will be notified when a new batch is added
347 * to the queue, allowing them to wake up and continue loading batches.
348 * */
349 std::condition_variable queue_cv_;
350
351 /* * The condition variable used to signal the iterator when a batch is ready
352 * to be retrieved from the batch queue. The iterator will wait on this
353 * condition variable when it is waiting for a batch to be loaded, and will be
354 * notified when a new batch is added to the queue, allowing it to wake up and
355 * retrieve the batch.
356 * */
357 std::condition_variable ready_cv_;
358
359 /* * Atomic flags for controlling the worker threads and the epoch state. The
360 * stop_workers_ flag is used to signal the worker threads to stop loading
361 * batches and exit when the epoch is finished or the dataloader is destroyed.
362 * The epoch_finished_ flag is used to indicate whether the current epoch is
363 * finished, which can be checked by the worker threads to determine when to
364 * stop loading batches.
365 * */
366 std::atomic<bool> stop_workers_;
367
368 /* * The epoch_finished_ flag is used to indicate whether the current epoch is
369 * finished. It is set to true when the iteration over the dataloader is
370 * complete or when the dataloader is destroyed, and it is checked by the
371 * worker threads to determine when to stop loading batches. When this flag is
372 * set to true, the worker threads will stop loading batches and exit
373 * gracefully.
374 * */
375 std::atomic<bool> epoch_finished_;
376
377 /* * The current batch index that the iterator is pointing to. This is used by
378 * the iterator to determine which batch to load from the dataloader when it
379 * is dereferenced or advanced. The current batch index is updated by the
380 * iterator as it advances through the batches, and it is used to retrieve the
381 * correct batch from the dataloader when requested.
382 * */
383 size_t current_batch_idx_;
384};
385
386} // namespace ReUseX::vision
Iterator(const Iterator &other)
Iterator & operator=(Iterator &&)=default
Iterator(Dataloader *loader, size_t batch_idx)
bool operator!=(const Iterator &other) const
bool operator==(const Iterator &other) const
std::input_iterator_tag iterator_category
Iterator & operator=(const Iterator &other)
size_t get_prefetch_batches() const
void set_prefetch_batches(size_t prefetch_batches)
Dataloader(IDataset &dataset, size_t batch_size, bool shuffle=false, size_t num_workers=4, size_t prefetch_batches=2)
void set_num_workers(size_t num_workers)
std::span< Pair > BatchView
size_t get_num_workers() const
std::vector< Pair > Batch
std::pair< std::unique_ptr< IData >, size_t > Pair
Definition IDataset.hpp:41