1
Fork 0

await tasks in shutdown (not proven to work)

This commit is contained in:
Andy Killorin 2023-12-24 17:55:26 -06:00
parent 923fa7aa9b
commit caef523357
Signed by: ank
GPG key ID: B6241CA3B552BCA4
4 changed files with 51 additions and 12 deletions

View file

@ -16,7 +16,7 @@ use rstar::RTree;
use names::Name; use names::Name;
use tasks::Scheduler; use tasks::Scheduler;
use tokio::{sync::{ use tokio::{sync::{
RwLock, mpsc, OnceCell, Mutex RwLock, mpsc, OnceCell, Mutex, watch
}, fs, time::Instant}; }, fs, time::Instant};
use turtle::{Turtle, TurtleCommander}; use turtle::{Turtle, TurtleCommander};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -54,7 +54,9 @@ async fn main() -> Result<(), Error> {
log4rs::init_file(SAVE.get().unwrap().join("log.yml"), Default::default())?; log4rs::init_file(SAVE.get().unwrap().join("log.yml"), Default::default())?;
let state = read_from_disk().await?; let (kill_send, kill_recv) = watch::channel(());
let state = read_from_disk(kill_send).await?;
let state = SharedControl::new(RwLock::new(state)); let state = SharedControl::new(RwLock::new(state));
@ -67,13 +69,13 @@ async fn main() -> Result<(), Error> {
let listener = tokio::net::TcpListener::bind(("0.0.0.0", *PORT.get().unwrap())) let listener = tokio::net::TcpListener::bind(("0.0.0.0", *PORT.get().unwrap()))
.await.unwrap(); .await.unwrap();
let server = safe_kill::serve(server, listener).await; safe_kill::serve(server, listener, kill_recv).await;
info!("writing"); info!("writing");
write_to_disk(&*state.read().await).await?; write_to_disk(&*state.read().await).await?;
info!("written"); info!("written");
server.closed().await; state.write().await.kill.closed().await;
Ok(()) Ok(())
} }
@ -101,7 +103,7 @@ async fn write_to_disk(state: &LiveState) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn read_from_disk() -> anyhow::Result<LiveState> { async fn read_from_disk(kill: watch::Sender<()>) -> anyhow::Result<LiveState> {
let turtles = match tokio::fs::OpenOptions::new() let turtles = match tokio::fs::OpenOptions::new()
.read(true) .read(true)
.open(SAVE.get().unwrap().join("turtles.json")) .open(SAVE.get().unwrap().join("turtles.json"))
@ -154,7 +156,7 @@ async fn read_from_disk() -> anyhow::Result<LiveState> {
depots, depots,
}; };
let mut live = LiveState::from_save(saved, scheduler); let mut live = LiveState::from_save(saved, scheduler, kill);
for turtle in live.turtles.iter() { for turtle in live.turtles.iter() {
live.tasks.add_turtle(&TurtleCommander::new(turtle.read().await.name,&live).await.unwrap()) live.tasks.add_turtle(&TurtleCommander::new(turtle.read().await.name,&live).await.unwrap())
@ -177,6 +179,7 @@ struct LiveState {
world: blocks::World, world: blocks::World,
depots: Depots, depots: Depots,
started: Instant, started: Instant,
kill: watch::Sender<()>,
} }
impl LiveState { impl LiveState {
@ -189,7 +192,7 @@ impl LiveState {
SavedState { turtles, world: self.world.tree().await, depots } SavedState { turtles, world: self.world.tree().await, depots }
} }
fn from_save(save: SavedState, scheduler: Scheduler) -> Self { fn from_save(save: SavedState, scheduler: Scheduler, sender: watch::Sender<()>) -> Self {
let mut turtles = Vec::new(); let mut turtles = Vec::new();
for turtle in save.turtles.into_iter() { for turtle in save.turtles.into_iter() {
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
@ -200,6 +203,7 @@ impl LiveState {
Self { turtles: turtles.into_iter().map(|t| Arc::new(RwLock::new(t))).collect(), tasks: scheduler, world: World::from_tree(save.world), Self { turtles: turtles.into_iter().map(|t| Arc::new(RwLock::new(t))).collect(), tasks: scheduler, world: World::from_tree(save.world),
depots, depots,
started: Instant::now(), started: Instant::now(),
kill:sender,
} }
} }

View file

@ -20,9 +20,7 @@ use tokio::sync::watch::Sender;
use axum::Router; use axum::Router;
pub(crate) async fn serve(server: Router, listener: TcpListener) -> Sender<()> { pub(crate) async fn serve(server: Router, listener: TcpListener, close_rx: watch::Receiver<()>) {
let (close_tx, close_rx) = watch::channel(());
loop { loop {
let (socket, _) = tokio::select! { let (socket, _) = tokio::select! {
result = listener.accept() => { result = listener.accept() => {
@ -69,8 +67,6 @@ pub(crate) async fn serve(server: Router, listener: TcpListener) -> Sender<()> {
} }
drop(listener); drop(listener);
close_tx
} }
pub(crate) async fn shutdown_signal() { pub(crate) async fn shutdown_signal() {

View file

@ -1,5 +1,6 @@
use log::{info, trace}; use log::{info, trace};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tokio::task::{JoinHandle, AbortHandle}; use tokio::task::{JoinHandle, AbortHandle};
use crate::names::Name; use crate::names::Name;
@ -24,6 +25,8 @@ pub struct Scheduler {
#[serde(skip)] #[serde(skip)]
turtles: Vec<(TurtleCommander, Option<AbortHandle>)>, turtles: Vec<(TurtleCommander, Option<AbortHandle>)>,
tasks: Vec<Box<dyn Task>>, tasks: Vec<Box<dyn Task>>,
#[serde(skip)]
shutdown: Option<oneshot::Sender<()>>,
} }
impl Default for Scheduler { impl Default for Scheduler {
@ -31,6 +34,7 @@ impl Default for Scheduler {
Self { Self {
turtles: Vec::new(), turtles: Vec::new(),
tasks: Vec::new(), tasks: Vec::new(),
shutdown:None,
} }
} }
} }
@ -63,6 +67,14 @@ impl Scheduler {
} }
} }
if self.shutdown.is_some() {
if !self.turtles.iter().any(|t| t.1.is_some()) {
self.shutdown.take().unwrap().send(()).unwrap();
}
return;
}
let mut free_turtles: Vec<&mut (TurtleCommander, Option<AbortHandle>)> = let mut free_turtles: Vec<&mut (TurtleCommander, Option<AbortHandle>)> =
self.turtles.iter_mut().filter(|t| t.1.is_none()).collect(); self.turtles.iter_mut().filter(|t| t.1.is_none()).collect();
@ -108,4 +120,10 @@ impl Scheduler {
} }
Some(()) Some(())
} }
pub fn shutdown(&mut self) -> oneshot::Receiver<()>{
let (send, recv) = oneshot::channel();
self.shutdown = Some(send);
recv
}
} }

View file

@ -50,6 +50,7 @@ pub fn turtle_api() -> Router<SharedControl> {
.route("/createMine", post(dig)) .route("/createMine", post(dig))
.route("/registerDepot", post(new_depot)) .route("/registerDepot", post(new_depot))
.route("/pollScheduler", get(poll)) .route("/pollScheduler", get(poll))
.route("/shutdown", get(shutdown)) // probably tramples the rfc
.route("/updateAll", get(update_turtles)) .route("/updateAll", get(update_turtles))
} }
@ -138,6 +139,26 @@ pub(crate) async fn poll(
"ACK" "ACK"
} }
pub(crate) async fn shutdown(
State(state): State<SharedControl>,
) -> &'static str {
let signal = {
let mut state = state.write().await;
let signal = state.tasks.shutdown();
state.tasks.poll().await;
signal
};
info!("waiting for tasks to finish");
signal.await.unwrap();
let state = state.write().await;
info!("waiting for connections to finish");
state.kill.send(()).unwrap();
"ACK"
}
pub(crate) async fn fell( pub(crate) async fn fell(
State(state): State<SharedControl>, State(state): State<SharedControl>,
Json(req): Json<Vec3>, Json(req): Json<Vec3>,