diff --git a/examples/rt685s-evk/src/bin/fw_update.rs b/examples/rt685s-evk/src/bin/fw_update.rs index 6e5dde8..6c59080 100644 --- a/examples/rt685s-evk/src/bin/fw_update.rs +++ b/examples/rt685s-evk/src/bin/fw_update.rs @@ -27,10 +27,10 @@ bind_interrupts!(struct Irqs { type Bus<'a> = I2cDevice<'a, NoopRawMutex, I2cMaster<'a, Async>>; type Controller<'a> = pd_controller::controller::Controller>; -type Interrupt<'a> = pd_controller::Interrupt<'a, NoopRawMutex, Bus<'a>>; +type InterruptProcessor<'a> = pd_controller::interrupt::InterruptProcessor<'a, NoopRawMutex, Bus<'a>>; #[embassy_executor::task] -async fn interrupt_task(mut int_in: Input<'static>, mut interrupt: Interrupt<'static>) { +async fn interrupt_task(mut int_in: Input<'static>, mut interrupt: InterruptProcessor<'static>) { pd_controller::task::interrupt_task(&mut int_in, [&mut interrupt].as_mut_slice()).await; } @@ -47,15 +47,15 @@ async fn main(spawner: Spawner) { let device = I2cDevice::new(bus); static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(Controller::new_tps66994(device, ADDR0).unwrap()); - let (mut pd, interrupt) = controller.make_parts(); + let controller = CONTROLLER.init(Controller::new_tps66994(device, Default::default(), ADDR0).unwrap()); + let (mut pd, interrupt_processor, _interrupt_receiver) = controller.make_parts(); let mut delay = Delay; info!("Resetting PD controller"); pd.reset(&mut delay).await.unwrap(); info!("Spawing PD interrupt task"); - spawner.spawn(interrupt_task(int_in, interrupt).unwrap()); + spawner.spawn(interrupt_task(int_in, interrupt_processor).unwrap()); let pd_fw_bytes = [0u8].as_slice(); //include_bytes!("../../fw.bin").as_slice(); diff --git a/examples/rt685s-evk/src/bin/plug_status.rs b/examples/rt685s-evk/src/bin/plug_status.rs index bc6b1c0..2010157 100644 --- a/examples/rt685s-evk/src/bin/plug_status.rs +++ b/examples/rt685s-evk/src/bin/plug_status.rs @@ -26,10 +26,10 @@ bind_interrupts!(struct Irqs { type Bus<'a> = I2cDevice<'a, NoopRawMutex, I2cMaster<'a, Async>>; type Controller<'a> = pd_controller::controller::Controller>; -type Interrupt<'a> = pd_controller::Interrupt<'a, NoopRawMutex, Bus<'a>>; +type InterruptProcessor<'a> = pd_controller::interrupt::InterruptProcessor<'a, NoopRawMutex, Bus<'a>>; #[embassy_executor::task] -async fn interrupt_task(mut int_in: Input<'static>, mut interrupt: Interrupt<'static>) { +async fn interrupt_task(mut int_in: Input<'static>, mut interrupt: InterruptProcessor<'static>) { pd_controller::task::interrupt_task(&mut int_in, [&mut interrupt].as_mut_slice()).await; } @@ -46,17 +46,17 @@ async fn main(spawner: Spawner) { let device = I2cDevice::new(bus); static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(Controller::new_tps66994(device, ADDR0).unwrap()); - let (mut pd, interrupt) = controller.make_parts(); + let controller = CONTROLLER.init(Controller::new_tps66994(device, Default::default(), ADDR0).unwrap()); + let (mut pd, interrupt_processor, mut interrupt_receiver) = controller.make_parts(); info!("Spawing PD interrupt task"); - spawner.spawn(interrupt_task(int_in, interrupt).unwrap()); + spawner.spawn(interrupt_task(int_in, interrupt_processor).unwrap()); loop { let mut plug_event_mask = IntEventBus1::new_zero(); plug_event_mask.set_plug_event(true); - let flags = pd - .wait_interrupt_any(false, [plug_event_mask; MAX_SUPPORTED_PORTS]) + let flags = interrupt_receiver + .wait_any_masked(false, [plug_event_mask; MAX_SUPPORTED_PORTS]) .await; for (i, flag) in flags.iter().enumerate().take(pd.num_ports()) { diff --git a/src/asynchronous/embassy/interrupt.rs b/src/asynchronous/embassy/interrupt.rs new file mode 100644 index 0000000..f29dd83 --- /dev/null +++ b/src/asynchronous/embassy/interrupt.rs @@ -0,0 +1,446 @@ +//! Interrupt related code. +use core::array::from_fn; + +use embassy_sync::blocking_mutex::raw::RawMutex; +use embassy_time::{with_timeout, Duration}; +use embedded_hal::digital::InputPin; +use embedded_hal_async::i2c::I2c; +use embedded_usb_pd::{Error, LocalPortId, PdError}; +use itertools::izip; + +use crate::asynchronous::embassy::controller::Controller; +use crate::registers::field_sets::IntEventBus1; +use crate::{error, trace, warn, MAX_SUPPORTED_PORTS}; + +/// Configuration for [`InterruptProcessor`] +#[non_exhaustive] +pub struct Config { + pub interrupt_timeout: Duration, +} + +impl Default for Config { + fn default() -> Self { + Self { + interrupt_timeout: Duration::from_millis(100), + } + } +} + +/// Struct for processing interrupts from the TPS6699x. +pub struct InterruptProcessor<'a, M: RawMutex, B: I2c> { + pub(super) controller: &'a Controller, +} + +impl<'a, M: RawMutex, B: I2c> InterruptProcessor<'a, M, B> { + /// Process interrupts + pub async fn process_interrupt( + &mut self, + int: &mut impl InputPin, + ) -> Result<[IntEventBus1; MAX_SUPPORTED_PORTS], Error> { + let timeout = self.controller.config.interrupt_processor_config.interrupt_timeout; + let mut flags = self + .controller + .interrupt_waker + .try_take() + .unwrap_or([IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]); + + let interrupts_enabled = self.controller.interrupts_enabled(); + let mut inner = self.controller.inner.lock().await; + + // Note: `interrupts_enabled` and `flags` are both of size MAX_SUPPORTED_PORTS and so + // will always have a 1:1 mapping. If `num_ports` ever returns a value larger than + // MAX_SUPPORTED_PORTS, `port` will simply be capped at MAX_SUPPORTED_PORTS. + for (port, (interrupt_enabled, flag, command_complete)) in izip!( + interrupts_enabled.iter(), + flags.iter_mut(), + self.controller.command_complete.iter() + ) + .take(self.controller.num_ports) + .enumerate() + { + let port_id = LocalPortId(port as u8); + + if !interrupt_enabled { + trace!("{:?}: Interrupts disabled", port_id); + continue; + } + + match int.is_high() { + Ok(true) => { + // Early exit if checking the last port cleared the interrupt + trace!("Interrupt line is high, exiting"); + break; + } + Err(_) => { + error!("Failed to read interrupt line"); + return PdError::Failed.into(); + } + _ => {} + } + + match with_timeout(timeout, inner.clear_interrupt(port_id)).await { + Ok(res) => match res { + Ok(event) => { + *flag |= event; + if event.cmd_1_completed() { + command_complete.signal(()); + } + } + Err(_) => { + error!("{:?}: clear_interrupt failed", port_id); + continue; + } + }, + Err(_) => { + error!("{:?}: clear_interrupt timeout", port_id); + continue; + } + } + } + + self.controller.interrupt_waker.signal(flags); + Ok(flags) + } +} + +/// Restores the original interrupt state when dropped +pub struct InterruptGuard<'a, M: RawMutex, B: I2c> { + target_state: [bool; MAX_SUPPORTED_PORTS], + controller: &'a Controller, +} + +impl<'a, M: RawMutex, B: I2c> InterruptGuard<'a, M, B> { + pub(super) fn new(controller: &'a Controller, enabled: [bool; MAX_SUPPORTED_PORTS]) -> Self { + let target_state = controller.interrupts_enabled(); + controller.enable_interrupts(enabled); + Self { + target_state, + controller, + } + } +} + +impl Drop for InterruptGuard<'_, M, B> { + fn drop(&mut self) { + self.controller.enable_interrupts(self.target_state); + } +} + +impl crate::asynchronous::interrupt::InterruptGuard for InterruptGuard<'_, M, B> {} + +/// Struct to ensure drop-safety of [`InterruptReceiver::wait_any_masked`] +/// +/// This struct re-signals any unhandled interrupts on drop. +struct AccumulatedFlagsAny<'a, M: RawMutex, B: I2c> { + controller: &'a Controller, + accumulated_flags: [IntEventBus1; MAX_SUPPORTED_PORTS], + masks: [IntEventBus1; MAX_SUPPORTED_PORTS], +} + +impl<'a, M: RawMutex, B: I2c> AccumulatedFlagsAny<'a, M, B> { + fn new(controller: &'a Controller, masks: [IntEventBus1; MAX_SUPPORTED_PORTS]) -> Self { + AccumulatedFlagsAny { + controller, + accumulated_flags: [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS], + masks, + } + } + + fn accumulate( + &mut self, + flags: [IntEventBus1; MAX_SUPPORTED_PORTS], + ) -> Option<[IntEventBus1; MAX_SUPPORTED_PORTS]> { + let mut done = false; + for (&flags, &mask, accumulated) in izip!(flags.iter(), self.masks.iter(), self.accumulated_flags.iter_mut(),) { + *accumulated |= flags; + let consumed_flags = flags & mask; + if consumed_flags != IntEventBus1::new_zero() { + done = true; + } + } + + if done { + // Panic safety: the return type, `accumulated_flags`, and `mask` are all of size MAX_SUPPORTED_PORTS + // so this will never index out of bounds + #[allow(clippy::indexing_slicing)] + let handled = from_fn(|i| self.accumulated_flags[i] & self.masks[i]); + // Put unhandled flags back for signaling in `drop()` + self.accumulated_flags = from_fn(|i| self.accumulated_flags[i] & !self.masks[i]); + Some(handled) + } else { + None + } + } +} + +impl Drop for AccumulatedFlagsAny<'_, M, B> { + fn drop(&mut self) { + // Catch any flags that may have happened since the last accumulate. + let new = self + .controller + .interrupt_waker + .try_take() + .unwrap_or([IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]); + // Panic safety: `unhandled`, `accumulated_flags`, and `mask` are all of size MAX_SUPPORTED_PORTS + // so this will never index out of bounds + #[allow(clippy::indexing_slicing)] + let unhandled = from_fn(|i| self.accumulated_flags[i] | new[i]); + + // Put back any unhandled interrupt flags for future processing + if unhandled.iter().any(|&f| f != IntEventBus1::new_zero()) { + // If there are unhandled flags, signal them for future processing + trace!("Signaling unhandled interrupt flags: {:?}", unhandled); + self.controller.interrupt_waker.signal(unhandled); + } + } +} + +/// Struct used to receive interrupts from the TPS6699x. +/// +/// +pub struct InterruptReceiver<'a, M: RawMutex, B: I2c> { + pub(super) controller: &'a Controller, +} + +impl<'a, M: RawMutex, B: I2c> InterruptReceiver<'a, M, B> { + /// Wait for an interrupt to occur. + /// + /// Drop safety: Safe, unhandled interrupts will be re-signaled. + pub async fn wait_any(&mut self, clean_current: bool) -> [IntEventBus1; MAX_SUPPORTED_PORTS] { + self.wait_any_masked(clean_current, [IntEventBus1::all(); MAX_SUPPORTED_PORTS]) + .await + } + + /// Wait for an interrupt to occur that matches any bits in the given mask. + pub async fn wait_any_masked( + &mut self, + clear_current: bool, + mask: [IntEventBus1; MAX_SUPPORTED_PORTS], + ) -> [IntEventBus1; MAX_SUPPORTED_PORTS] { + // No interrupts set, return immediately because there is nothing to wait for + // Also log a warning because this likely isn't what the user intended + if mask == [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS] { + warn!("Interrupt masks are empty, returning immediately"); + return [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]; + } + + if clear_current { + self.controller.interrupt_waker.reset(); + } + + let mut accumulated_flags = AccumulatedFlagsAny::new(self.controller, mask); + loop { + let flags = self.controller.interrupt_waker.wait().await; + if let Some(flags) = accumulated_flags.accumulate(flags) { + return flags; + } + } + } +} + +#[cfg(test)] +mod test { + use embassy_sync::blocking_mutex::raw::NoopRawMutex; + use embassy_time::{with_timeout, Duration, TimeoutError}; + use embedded_hal_mock::eh1::i2c::Mock; + use static_cell::StaticCell; + + use super::*; + use crate::asynchronous::embassy::controller::Controller; + use crate::ADDR0; + + /// Tests `wait_any_masked` with a mask for both ports. + #[tokio::test] + async fn test_wait_any_masked_both() { + static CONTROLLER: StaticCell> = StaticCell::new(); + let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), Default::default(), ADDR0).unwrap()); + let (pd, _processor, mut receiver) = controller.make_parts(); + + let mut port0 = IntEventBus1::new_zero(); + port0.set_new_consumer_contract(true); + port0.set_sink_ready(true); + port0.set_cmd_1_completed(true); + + let mut port1 = IntEventBus1::new_zero(); + port1.set_plug_event(true); + port1.set_alert_message_received(true); + + pd.controller.interrupt_waker.signal([port0, port1]); + + let mut mask0 = IntEventBus1::new_zero(); + mask0.set_cmd_1_completed(true); + + let mut mask1 = IntEventBus1::new_zero(); + mask1.set_plug_event(true); + mask1.set_alert_message_received(true); + + let flags = receiver.wait_any_masked(false, [mask0, mask1]).await; + assert_eq!(flags, [mask0, mask1]); + + let mut unhandled0 = IntEventBus1::new_zero(); + unhandled0.set_new_consumer_contract(true); + unhandled0.set_sink_ready(true); + + let unhandled1 = IntEventBus1::new_zero(); + + // Should already be signaled + assert_eq!( + pd.controller.interrupt_waker.try_take().unwrap(), + [unhandled0, unhandled1] + ); + } + + /// Tests `wait_any_masked` with a mask for a single port. + #[tokio::test] + async fn test_wait_any_masked_single() { + static CONTROLLER: StaticCell> = StaticCell::new(); + let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), Default::default(), ADDR0).unwrap()); + let (pd, _processor, mut receiver) = controller.make_parts(); + + let mut port0 = IntEventBus1::new_zero(); + port0.set_new_consumer_contract(true); + port0.set_sink_ready(true); + port0.set_cmd_1_completed(true); + + let mut port1 = IntEventBus1::new_zero(); + port1.set_plug_event(true); + port1.set_alert_message_received(true); + + pd.controller.interrupt_waker.signal([port0, port1]); + + let mut mask0 = IntEventBus1::new_zero(); + mask0.set_cmd_1_completed(true); + + let mask1 = IntEventBus1::new_zero(); + + let flags = receiver.wait_any_masked(false, [mask0, mask1]).await; + assert_eq!(flags, [mask0, mask1]); + + let mut unhandled0 = IntEventBus1::new_zero(); + unhandled0.set_new_consumer_contract(true); + unhandled0.set_sink_ready(true); + + let unhandled1 = port1; + + // Should already be signaled + assert_eq!( + pd.controller.interrupt_waker.try_take().unwrap(), + [unhandled0, unhandled1] + ); + } + + /// Tests `wait_any_masked` with both masks set to zero. + #[tokio::test] + async fn test_wait_any_masked_zero_masks() { + static CONTROLLER: StaticCell> = StaticCell::new(); + let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), Default::default(), ADDR0).unwrap()); + let (pd, _processor, mut receiver) = controller.make_parts(); + + let mut port0 = IntEventBus1::new_zero(); + port0.set_new_consumer_contract(true); + port0.set_sink_ready(true); + port0.set_cmd_1_completed(true); + + let mut port1 = IntEventBus1::new_zero(); + port1.set_plug_event(true); + port1.set_alert_message_received(true); + + pd.controller.interrupt_waker.signal([port0, port1]); + + let mask0 = IntEventBus1::new_zero(); + let mask1 = IntEventBus1::new_zero(); + let flags = receiver.wait_any_masked(false, [mask0, mask1]).await; + assert_eq!(flags, [mask0, mask1]); + + // Should already be signaled with nothing changed + assert_eq!(pd.controller.interrupt_waker.try_take().unwrap(), [port0, port1]); + } + + #[tokio::test] + async fn test_wait_any_masked_timeout() { + // Port0 mocked pending interrupts + let mut port0 = IntEventBus1::new_zero(); + port0.set_new_consumer_contract(true); + + // Port1 mocked pending interrupts + let mut port1 = IntEventBus1::new_zero(); + port1.set_plug_event(true); + + static CONTROLLER: StaticCell> = StaticCell::new(); + let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), Default::default(), ADDR0).unwrap()); + let (pd, _processor, mut receiver) = controller.make_parts(); + + pd.controller.interrupt_waker.signal([port0, port1]); + + // The mask doesn't match the pending interrupts, so we should get a timeout + let mut mask0 = IntEventBus1::new_zero(); + mask0.set_cmd_1_completed(true); + + let mut mask1 = IntEventBus1::new_zero(); + mask1.set_new_provider_contract(true); + + assert_eq!( + with_timeout( + Duration::from_millis(10), + receiver.wait_any_masked(false, [mask0, mask1]) + ) + .await, + Err(TimeoutError) + ); + + // Use all mask to get leftover interrupts + let mut leftover0 = IntEventBus1::new_zero(); + leftover0.set_new_consumer_contract(true); + + let mut leftover1 = IntEventBus1::new_zero(); + leftover1.set_plug_event(true); + + let leftover_flags = with_timeout( + Duration::from_millis(10), + receiver.wait_any_masked(false, [IntEventBus1::all(), IntEventBus1::all()]), + ) + .await + .unwrap(); + assert_eq!(leftover_flags[0], leftover0); + assert_eq!(leftover_flags[1], leftover1); + } + + /// Tests `wait_any`. + #[tokio::test] + async fn test_wait_any() { + static CONTROLLER: StaticCell> = StaticCell::new(); + let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), Default::default(), ADDR0).unwrap()); + let (pd, _processor, mut receiver) = controller.make_parts(); + + let mut port0 = IntEventBus1::new_zero(); + port0.set_new_consumer_contract(true); + port0.set_sink_ready(true); + port0.set_cmd_1_completed(true); + + let mut port1 = IntEventBus1::new_zero(); + port1.set_plug_event(true); + port1.set_alert_message_received(true); + + pd.controller.interrupt_waker.signal([port0, port1]); + + let mut flags0 = IntEventBus1::new_zero(); + flags0.set_new_consumer_contract(true); + flags0.set_sink_ready(true); + flags0.set_cmd_1_completed(true); + + let mut flags1 = IntEventBus1::new_zero(); + flags1.set_plug_event(true); + flags1.set_alert_message_received(true); + + let flags = receiver.wait_any(false).await; + assert_eq!(flags, [flags0, flags1]); + + // This should timeout because there are no leftover interrupts + let leftover_flags = with_timeout( + Duration::from_millis(10), + receiver.wait_any_masked(false, [IntEventBus1::all(), IntEventBus1::all()]), + ) + .await; + assert_eq!(leftover_flags, Err(TimeoutError)); + } +} diff --git a/src/asynchronous/embassy/mod.rs b/src/asynchronous/embassy/mod.rs index c114f3e..8d63444 100644 --- a/src/asynchronous/embassy/mod.rs +++ b/src/asynchronous/embassy/mod.rs @@ -1,5 +1,4 @@ //! This module contains a high-level API uses embassy synchronization types -use core::array::from_fn; use core::future::Future; use core::iter::zip; use core::sync::atomic::AtomicBool; @@ -8,35 +7,47 @@ use bincode::config; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::mutex::{Mutex, MutexGuard}; use embassy_sync::signal::Signal; -use embassy_time::{with_timeout, Duration, Timer}; -use embedded_hal::digital::InputPin; +use embassy_time::{with_timeout, Timer}; use embedded_hal_async::delay::DelayNs; use embedded_hal_async::i2c::I2c; use embedded_usb_pd::ado::{self, Ado}; use embedded_usb_pd::pdinfo::AltMode; use embedded_usb_pd::{pdo, Error, LocalPortId, PdError}; -use itertools::izip; -use super::interrupt::{self, InterruptController}; +use crate::asynchronous::embassy::interrupt::InterruptReceiver; use crate::asynchronous::internal; +use crate::asynchronous::interrupt::InterruptController; use crate::command::{gcdm, muxr, trig, vdms, Command, ReturnValue, SrdySwitch}; use crate::registers::autonegotiate_sink::AutoComputeSinkMaxVoltage; use crate::registers::field_sets::IntEventBus1; -use crate::{error, registers, trace, warn, DeviceError, Mode, MAX_SUPPORTED_PORTS}; +use crate::{error, registers, trace, DeviceError, Mode, MAX_SUPPORTED_PORTS}; pub mod fw_update; +pub mod interrupt; pub mod rx_caps; pub mod task; pub mod ucsi; pub mod controller { use super::*; + use crate::asynchronous::embassy::interrupt::InterruptProcessor; use crate::{TPS66993_NUM_PORTS, TPS66994_NUM_PORTS}; + /// Configuration for [`Controller`] + #[derive(Default)] + #[non_exhaustive] + pub struct Config { + pub interrupt_processor_config: crate::asynchronous::embassy::interrupt::Config, + } + /// Controller struct. This struct is meant to be created and then immediately broken into its parts pub struct Controller { + /// Config + pub(super) config: Config, /// Low-level TPS6699x driver pub(super) inner: Mutex>, + /// Command completion signals + pub(super) command_complete: [Signal; MAX_SUPPORTED_PORTS], /// Signal for awaiting an interrupt pub(super) interrupt_waker: Signal, /// Current interrupt state @@ -47,30 +58,44 @@ pub mod controller { impl Controller { /// Private constructor - pub fn new(bus: B, addr: [u8; MAX_SUPPORTED_PORTS], num_ports: usize) -> Result> { + fn new( + bus: B, + config: Config, + addr: [u8; MAX_SUPPORTED_PORTS], + num_ports: usize, + ) -> Result> { Ok(Self { + config, inner: Mutex::new(internal::Tps6699x::new(bus, addr, num_ports)), interrupt_waker: Signal::new(), + command_complete: [const { Signal::new() }; MAX_SUPPORTED_PORTS], interrupts_enabled: [const { AtomicBool::new(true) }; MAX_SUPPORTED_PORTS], num_ports, }) } /// Create a new controller for the TPS66993 - pub fn new_tps66993(bus: B, addr: u8) -> Result> { - Self::new(bus, [addr, 0], TPS66993_NUM_PORTS) + pub fn new_tps66993(bus: B, config: Config, addr: u8) -> Result> { + Self::new(bus, config, [addr, 0], TPS66993_NUM_PORTS) } /// Create a new controller for the TPS66994 - pub fn new_tps66994(bus: B, addr: [u8; TPS66994_NUM_PORTS]) -> Result> { - Self::new(bus, addr, TPS66994_NUM_PORTS) + pub fn new_tps66994(bus: B, config: Config, addr: [u8; TPS66994_NUM_PORTS]) -> Result> { + Self::new(bus, config, addr, TPS66994_NUM_PORTS) } /// Breaks the controller into its parts - pub fn make_parts(&mut self) -> (Tps6699x<'_, M, B>, Interrupt<'_, M, B>) { + pub fn make_parts( + &mut self, + ) -> ( + Tps6699x<'_, M, B>, + InterruptProcessor<'_, M, B>, + InterruptReceiver<'_, M, B>, + ) { let tps = Tps6699x { controller: self }; - let interrupt = Interrupt { controller: self }; - (tps, interrupt) + let interrupt = InterruptProcessor { controller: self }; + let receiver = InterruptReceiver { controller: self }; + (tps, interrupt, receiver) } /// Enable or disable interrupts for the given ports @@ -243,34 +268,6 @@ impl<'a, M: RawMutex, B: I2c> Tps6699x<'a, M, B> { self.controller.num_ports } - /// Wait for an interrupt to occur that matches any bits in the given mask. - /// - /// Drop safety: Safe, unhandled interrupts will be re-signaled. - pub async fn wait_interrupt_any( - &mut self, - clear_current: bool, - mask: [IntEventBus1; MAX_SUPPORTED_PORTS], - ) -> [IntEventBus1; MAX_SUPPORTED_PORTS] { - // No interrupts set, return immediately because there is nothing to wait for - // Also log a warning because this likely isn't what the user intended - if mask == [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS] { - warn!("Interrupt masks are empty, returning immediately"); - return [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]; - } - - if clear_current { - self.controller.interrupt_waker.reset(); - } - - let mut accumulated_flags = AccumulatedFlagsAny::new(self.controller, mask); - loop { - let flags = self.controller.interrupt_waker.wait().await; - if let Some(flags) = accumulated_flags.accumulate(flags) { - return flags; - } - } - } - /// Execute the given command with no timeout async fn execute_command_no_timeout( &mut self, @@ -279,26 +276,25 @@ impl<'a, M: RawMutex, B: I2c> Tps6699x<'a, M, B> { indata: Option<&[u8]>, outdata: Option<&mut [u8]>, ) -> Result> { + // Size of the command_complete array is MAX_SUPPORTED_PORTS so the `get`` call below doesn't guarentee + // that the port is valid because it wouldn't catch trying to access a second port on a controller with + // only one port. + if port.0 as usize >= self.controller.num_ports { + return Err(Error::Pd(PdError::InvalidPort)); + } + + let command_complete = self + .controller + .command_complete + .get(port.0 as usize) + .ok_or(Error::Pd(PdError::InvalidPort))?; + command_complete.reset(); { let mut inner = self.lock_inner().await; inner.send_command(port, cmd, indata).await?; } - let mut cmd_complete = IntEventBus1::new_zero(); - cmd_complete.set_cmd_1_completed(true); - - let _flags = self - .wait_interrupt_any( - false, - from_fn(|i| { - if i == port.0 as usize { - cmd_complete - } else { - IntEventBus1::new_zero() - } - }), - ) - .await; + command_complete.wait().await; { let mut inner = self.lock_inner().await; inner.read_command_result(port, outdata, cmd.has_return_value()).await @@ -817,8 +813,8 @@ impl<'a, M: RawMutex, B: I2c> Tps6699x<'a, M, B> { } } -impl<'a, M: RawMutex, B: I2c> interrupt::InterruptController for Tps6699x<'a, M, B> { - type Guard = InterruptGuard<'a, M, B>; +impl<'a, M: RawMutex, B: I2c> InterruptController for Tps6699x<'a, M, B> { + type Guard = interrupt::InterruptGuard<'a, M, B>; type BusError = B::Error; async fn interrupts_enabled(&self) -> Result<[bool; MAX_SUPPORTED_PORTS], Error> { @@ -829,335 +825,6 @@ impl<'a, M: RawMutex, B: I2c> interrupt::InterruptController for Tps6699x<'a, M, &mut self, enabled: [bool; MAX_SUPPORTED_PORTS], ) -> Result> { - Ok(InterruptGuard::new(self.controller, enabled)) - } -} - -pub struct Interrupt<'a, M: RawMutex, B: I2c> { - controller: &'a controller::Controller, -} - -impl<'a, M: RawMutex, B: I2c> Interrupt<'a, M, B> { - fn lock_inner(&mut self) -> impl Future>> { - self.controller.inner.lock() - } - - /// Process interrupts - pub async fn process_interrupt( - &mut self, - int: &mut impl InputPin, - ) -> Result<[IntEventBus1; MAX_SUPPORTED_PORTS], Error> { - let mut flags = self - .controller - .interrupt_waker - .try_take() - .unwrap_or([IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]); - - { - let interrupts_enabled = self.controller.interrupts_enabled(); - let mut inner = self.lock_inner().await; - - // Note: `interrupts_enabled` and `flags` are both of size MAX_SUPPORTED_PORTS and so - // will always have a 1:1 mapping. If `num_ports` ever returns a value larger than - // MAX_SUPPORTED_PORTS, `port` will simply be capped at MAX_SUPPORTED_PORTS. - for (port, (interrupt_enabled, flag)) in interrupts_enabled - .iter() - .zip(flags.iter_mut()) - .take(inner.num_ports()) - .enumerate() - { - let port_id = LocalPortId(port as u8); - - if !interrupt_enabled { - trace!("{:?}: Interrupt for disabled", port_id); - continue; - } - - match int.is_high() { - Ok(true) => { - // Early exit if checking the last port cleared the interrupt - trace!("Interrupt line is high, exiting"); - continue; - } - Err(_) => { - error!("Failed to read interrupt line"); - return PdError::Failed.into(); - } - _ => {} - } - - match with_timeout(Duration::from_millis(100), inner.clear_interrupt(port_id)).await { - Ok(res) => match res { - Ok(event) => *flag |= event, - Err(_e) => { - continue; - } - }, - Err(_) => { - error!("{:?}: clear_interrupt timeout", port_id); - continue; - } - } - } - } - - self.controller.interrupt_waker.signal(flags); - Ok(flags) - } -} - -/// Restores the original interrupt state when dropped -pub struct InterruptGuard<'a, M: RawMutex, B: I2c> { - target_state: [bool; MAX_SUPPORTED_PORTS], - controller: &'a controller::Controller, -} - -impl<'a, M: RawMutex, B: I2c> InterruptGuard<'a, M, B> { - fn new(controller: &'a controller::Controller, enabled: [bool; MAX_SUPPORTED_PORTS]) -> Self { - let target_state = controller.interrupts_enabled(); - controller.enable_interrupts(enabled); - Self { - target_state, - controller, - } - } -} - -impl Drop for InterruptGuard<'_, M, B> { - fn drop(&mut self) { - self.controller.enable_interrupts(self.target_state); - } -} - -impl interrupt::InterruptGuard for InterruptGuard<'_, M, B> {} - -/// Struct to ensure drop-safety of [`Tps6699x::wait_interrupt_any`] -/// -/// This struct re-signals any unhandled interrupts on drop. -struct AccumulatedFlagsAny<'a, M: RawMutex, B: I2c> { - controller: &'a controller::Controller, - accumulated_flags: [IntEventBus1; MAX_SUPPORTED_PORTS], - masks: [IntEventBus1; MAX_SUPPORTED_PORTS], -} - -impl<'a, M: RawMutex, B: I2c> AccumulatedFlagsAny<'a, M, B> { - fn new(controller: &'a controller::Controller, masks: [IntEventBus1; MAX_SUPPORTED_PORTS]) -> Self { - AccumulatedFlagsAny { - controller, - accumulated_flags: [IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS], - masks, - } - } - - fn accumulate( - &mut self, - flags: [IntEventBus1; MAX_SUPPORTED_PORTS], - ) -> Option<[IntEventBus1; MAX_SUPPORTED_PORTS]> { - let mut done = false; - for (&flags, &mask, accumulated) in izip!(flags.iter(), self.masks.iter(), self.accumulated_flags.iter_mut(),) { - *accumulated |= flags; - let consumed_flags = flags & mask; - if consumed_flags != IntEventBus1::new_zero() { - done = true; - } - } - - if done { - // Panic safety: the return type, `accumulated_flags`, and `mask` are all of size MAX_SUPPORTED_PORTS - // so this will never index out of bounds - #[allow(clippy::indexing_slicing)] - let handled = from_fn(|i| self.accumulated_flags[i] & self.masks[i]); - // Put unhandled flags back for signaling in `drop()` - self.accumulated_flags = from_fn(|i| self.accumulated_flags[i] & !self.masks[i]); - Some(handled) - } else { - None - } - } -} - -impl Drop for AccumulatedFlagsAny<'_, M, B> { - fn drop(&mut self) { - // Catch any flags that may have happened since the last accumulate. - let new = self - .controller - .interrupt_waker - .try_take() - .unwrap_or([IntEventBus1::new_zero(); MAX_SUPPORTED_PORTS]); - // Panic safety: `unhandled`, `accumulated_flags`, and `mask` are all of size MAX_SUPPORTED_PORTS - // so this will never index out of bounds - #[allow(clippy::indexing_slicing)] - let unhandled = from_fn(|i| self.accumulated_flags[i] | new[i]); - - // Put back any unhandled interrupt flags for future processing - if unhandled.iter().any(|&f| f != IntEventBus1::new_zero()) { - // If there are unhandled flags, signal them for future processing - trace!("Signaling unhandled interrupt flags: {:?}", unhandled); - self.controller.interrupt_waker.signal(unhandled); - } - } -} - -#[cfg(test)] -mod test { - use embassy_sync::blocking_mutex::raw::NoopRawMutex; - use embassy_time::{with_timeout, Duration, TimeoutError}; - use embedded_hal_mock::eh1::i2c::Mock; - use static_cell::StaticCell; - - use super::*; - use crate::asynchronous::embassy::controller::Controller; - use crate::ADDR0; - - /// Tests `wait_interrupt_any` with a mask for both ports. - #[tokio::test] - async fn test_wait_interrupt_any_both() { - static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(controller::Controller::new_tps66994(Mock::new(&[]), ADDR0).unwrap()); - let (mut pd, _interrupt) = controller.make_parts(); - - let mut port0 = IntEventBus1::new_zero(); - port0.set_new_consumer_contract(true); - port0.set_sink_ready(true); - port0.set_cmd_1_completed(true); - - let mut port1 = IntEventBus1::new_zero(); - port1.set_plug_event(true); - port1.set_alert_message_received(true); - - pd.controller.interrupt_waker.signal([port0, port1]); - - let mut mask0 = IntEventBus1::new_zero(); - mask0.set_cmd_1_completed(true); - - let mut mask1 = IntEventBus1::new_zero(); - mask1.set_plug_event(true); - mask1.set_alert_message_received(true); - - let flags = pd.wait_interrupt_any(false, [mask0, mask1]).await; - assert_eq!(flags, [mask0, mask1]); - - let mut unhandled0 = IntEventBus1::new_zero(); - unhandled0.set_new_consumer_contract(true); - unhandled0.set_sink_ready(true); - - let unhandled1 = IntEventBus1::new_zero(); - - // Should already be signaled - assert_eq!( - pd.controller.interrupt_waker.try_take().unwrap(), - [unhandled0, unhandled1] - ); - } - - /// Tests `wait_interrupt` with a mask for a single port. - #[tokio::test] - async fn test_wait_interrupt_any_single() { - static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(controller::Controller::new_tps66994(Mock::new(&[]), ADDR0).unwrap()); - let (mut pd, _interrupt) = controller.make_parts(); - - let mut port0 = IntEventBus1::new_zero(); - port0.set_new_consumer_contract(true); - port0.set_sink_ready(true); - port0.set_cmd_1_completed(true); - - let mut port1 = IntEventBus1::new_zero(); - port1.set_plug_event(true); - port1.set_alert_message_received(true); - - pd.controller.interrupt_waker.signal([port0, port1]); - - let mut mask0 = IntEventBus1::new_zero(); - mask0.set_cmd_1_completed(true); - - let mask1 = IntEventBus1::new_zero(); - - let flags = pd.wait_interrupt_any(false, [mask0, mask1]).await; - assert_eq!(flags, [mask0, mask1]); - - let mut unhandled0 = IntEventBus1::new_zero(); - unhandled0.set_new_consumer_contract(true); - unhandled0.set_sink_ready(true); - - let unhandled1 = port1; - - // Should already be signaled - assert_eq!( - pd.controller.interrupt_waker.try_take().unwrap(), - [unhandled0, unhandled1] - ); - } - - /// Tests `wait_interrupt` with both masks set to zero. - #[tokio::test] - async fn test_wait_interrupt_any_zero_masks() { - static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(controller::Controller::new_tps66994(Mock::new(&[]), ADDR0).unwrap()); - let (mut pd, _interrupt) = controller.make_parts(); - - let mut port0 = IntEventBus1::new_zero(); - port0.set_new_consumer_contract(true); - port0.set_sink_ready(true); - port0.set_cmd_1_completed(true); - - let mut port1 = IntEventBus1::new_zero(); - port1.set_plug_event(true); - port1.set_alert_message_received(true); - - pd.controller.interrupt_waker.signal([port0, port1]); - - let mask0 = IntEventBus1::new_zero(); - let mask1 = IntEventBus1::new_zero(); - let flags = pd.wait_interrupt_any(false, [mask0, mask1]).await; - assert_eq!(flags, [mask0, mask1]); - - // Should already be signaled with nothing changed - assert_eq!(pd.controller.interrupt_waker.try_take().unwrap(), [port0, port1]); - } - - #[tokio::test] - async fn test_wait_interrupt_any_timeout() { - // Port0 mocked pending interrupts - let mut port0 = IntEventBus1::new_zero(); - port0.set_new_consumer_contract(true); - - // Port1 mocked pending interrupts - let mut port1 = IntEventBus1::new_zero(); - port1.set_plug_event(true); - - static CONTROLLER: StaticCell> = StaticCell::new(); - let controller = CONTROLLER.init(Controller::new_tps66994(Mock::new(&[]), ADDR0).unwrap()); - let (mut pd, _interrupt) = controller.make_parts(); - - pd.controller.interrupt_waker.signal([port0, port1]); - - // The mask doesn't match the pending interrupts, so we should get a timeout - let mut mask0 = IntEventBus1::new_zero(); - mask0.set_cmd_1_completed(true); - - let mut mask1 = IntEventBus1::new_zero(); - mask1.set_new_provider_contract(true); - - assert_eq!( - with_timeout(Duration::from_millis(10), pd.wait_interrupt_any(false, [mask0, mask1])).await, - Err(TimeoutError) - ); - - // Use all mask to get leftover interrupts - let mut leftover0 = IntEventBus1::new_zero(); - leftover0.set_new_consumer_contract(true); - - let mut leftover1 = IntEventBus1::new_zero(); - leftover1.set_plug_event(true); - - let leftover_flags = with_timeout( - Duration::from_millis(10), - pd.wait_interrupt_any(false, [IntEventBus1::all(), IntEventBus1::all()]), - ) - .await - .unwrap(); - assert_eq!(leftover_flags[0], leftover0); - assert_eq!(leftover_flags[1], leftover1); + Ok(interrupt::InterruptGuard::new(self.controller, enabled)) } } diff --git a/src/asynchronous/embassy/task.rs b/src/asynchronous/embassy/task.rs index 2a5889d..f56f138 100644 --- a/src/asynchronous/embassy/task.rs +++ b/src/asynchronous/embassy/task.rs @@ -3,13 +3,13 @@ use embedded_hal::digital::InputPin; use embedded_hal_async::digital::Wait; use embedded_hal_async::i2c::I2c; -use super::Interrupt; +use super::interrupt::InterruptProcessor; use crate::{error, trace, warn}; /// Task to process all given interrupts pub async fn interrupt_task( int: &mut INT, - interrupts: &mut [&mut Interrupt<'_, M, B>], + interrupts: &mut [&mut InterruptProcessor<'_, M, B>], ) { let mut retry_strategy = retry_strategy::ExponentialBackoff::default(); loop {