diff --git a/server/src/mine.rs b/server/src/mine.rs index d275f21..e20e11b 100644 --- a/server/src/mine.rs +++ b/server/src/mine.rs @@ -314,7 +314,7 @@ struct ChunkedTask { #[serde(skip_deserializing)] head: Arc, // highest active chunk #[serde(skip)] - in_flight: Arc>>, // must remain sorted + canceled: Arc>>, // must remain sorted max: i32, } @@ -323,55 +323,49 @@ impl ChunkedTask { Self { confirmed: Default::default(), head: Default::default(), - in_flight: Default::default(), + canceled: Default::default(), max: parts, } } fn done(&self) -> bool { let backstop = self.confirmed.load(Ordering::SeqCst); - backstop >= self.max + backstop + 1 >= self.max } fn allocated(&self) -> bool { let front = self.head.load(Ordering::SeqCst); front + 1 >= self.max } + async fn next_chunk(&self) -> Option { - let mut in_flight = self.in_flight.clone().write_owned().await; + let mut cancelled = self.canceled.clone().write_owned().await; - tracing::trace!("running: {:?}", in_flight); - - let backstop = self.confirmed.load(Ordering::SeqCst); - - // we have a mutex anyway - if let Some(highest) = in_flight.last() { - self.head.store(*highest, Ordering::SeqCst); + if let Some(chunk) = cancelled.pop() { + return Some(chunk); } - for i in backstop..self.max { - if !in_flight.contains(&i) { - in_flight.push(i); - in_flight.sort_unstable(); - info!("next: {i}"); - return Some(i); - } - } + let head = self.head.fetch_add(1, Ordering::AcqRel); - return None; + if head < self.max { + Some(head) + } else { + None + } } async fn mark_done(&self, chunk: i32) { - let mut in_flight = self.in_flight.write().await; + let canceled = self.canceled.read().await; - let min = in_flight.iter().max() == Some(&chunk); + let min = match canceled.iter().min() { + None => true, + Some(minima) => chunk < *minima, + }; - in_flight.retain(|c| c != &chunk); - - if min { // make sure that head is no less than min + if min { loop { let curr = self.confirmed.load(Ordering::SeqCst); - if let Ok(_) = self.confirmed.compare_exchange(curr, curr.max(chunk+1), Ordering::AcqRel, Ordering::SeqCst) { + if let Ok(_) = self.confirmed.compare_exchange(curr, curr.max(chunk), Ordering::AcqRel, Ordering::SeqCst) { break; } } @@ -379,8 +373,15 @@ impl ChunkedTask { } async fn cancel(&self, chunk: i32) { - let mut in_flight = self.in_flight.write().await; - in_flight.retain(|c| c != &chunk); + let mut in_flight = self.canceled.write().await; + let max = self.head.load(Ordering::SeqCst); + if chunk < max { + in_flight.push(chunk); + in_flight.sort_unstable(); + } + else { + error!("attempted to cancel a job that hasn't happened yet"); + } } } @@ -415,10 +416,25 @@ mod tests { async fn cancel_replay() { let tracker = ChunkedTask::new(5); assert_eq!(tracker.next_chunk().await, Some(0)); + tracker.mark_done(0).await; assert_eq!(tracker.next_chunk().await, Some(1)); - assert_eq!(tracker.next_chunk().await, Some(2)); + tracker.mark_done(1).await; tracker.cancel(2).await; assert_eq!(tracker.next_chunk().await, Some(2)); + assert_eq!(tracker.done(), false); + tracker.mark_done(2).await; + assert_eq!(tracker.next_chunk().await, Some(3)); + tracker.mark_done(3).await; + tracker.cancel(2).await; + assert_eq!(tracker.next_chunk().await, Some(2)); + assert_eq!(tracker.next_chunk().await, Some(4)); + tracker.cancel(1).await; + assert_eq!(tracker.next_chunk().await, Some(1)); + assert_eq!(tracker.done(), false); + assert_eq!(tracker.next_chunk().await, None); + assert_eq!(tracker.allocated(), true); + tracker.mark_done(4).await; + assert_eq!(tracker.done(), true); } #[tokio::test]