use std::{future::Future, time::Duration}; use anyhow::Result; use futures_lite::prelude::*; use lapin::{options::{BasicCancelOptions, BasicConsumeOptions, BasicQosOptions}, types::{DeliveryTag, FieldTable, ShortString, ShortUInt}, Channel, Connection}; use serde::de::DeserializeOwned; use tokio::{sync::mpsc::UnboundedSender, task::JoinSet}; use tokio_util::task::TaskTracker; use crate::ack::{AckBatcher, AckResult}; pub struct Processor { conn: Connection, chan: Channel, cons_tags: Vec, join_set: JoinSet>, ack_sink: UnboundedSender<(DeliveryTag, AckResult)>, } impl Processor { pub async fn new(conn: Connection, prefetch_count: ShortUInt) -> lapin::Result { let chan = conn.create_channel().await?; chan.basic_qos(prefetch_count, BasicQosOptions::default()).await?; let mut join_set = JoinSet::new(); let (ack_sink, ack_batcher) = AckBatcher::new(Duration::from_millis(100), chan.clone()); join_set.spawn(ack_batcher.run()); Ok(Self { conn, chan, cons_tags: vec![], ack_sink, join_set, }) } pub async fn listen>>( &mut self, queue: &str, consumer_tag: &str, f: impl 'static + Send + Fn(T) -> F, ) -> lapin::Result<()> { let mut consumer = self.chan.basic_consume( queue, consumer_tag, BasicConsumeOptions::default(), FieldTable::default() ).await?; self.cons_tags.push(consumer.tag()); let ack_batcher = self.ack_sink.clone(); self.join_set.spawn(async move { let tracker = TaskTracker::new(); while let Some(delivery) = consumer.try_next().await? { if let Ok(data) = serde_json::from_slice(&delivery.data) { let ack_batcher = ack_batcher.clone(); let fut = f(data); tracker.spawn(async move { ack_batcher.send((delivery.delivery_tag, fut.await.map_err(|_| true))).unwrap(); }); } else { ack_batcher.send((delivery.delivery_tag, Err(false))).unwrap(); } } tracker.close(); tracker.wait().await; lapin::Result::Ok(()) }); Ok(()) } pub async fn shutdown(self) -> lapin::Result<()> { let Self { cons_tags, conn, chan, mut join_set, ack_sink: ack_batcher } = self; for cons_tag in cons_tags { let chan = chan.clone(); // Required because of 'static on tokio::spawn. Might want to switch to join_all or just do serially join_set.spawn(async move { chan.basic_cancel(cons_tag.as_str(), BasicCancelOptions::default()).await }); } drop(ack_batcher); while let Some(res) = join_set.join_next().await { res.unwrap()?; } conn.close(0, "").await } }