BindsNET学习系列——BasePipeline

相关源码:bindsnet/bindsnet/pipeline/base_pipeline.py

class BasePipeline:
    # language=rst
    """
    A generic pipeline that handles high level functionality.

    The base class attaches spike / voltage monitors to the network, picks the
    device, and performs the periodic print / save / test bookkeeping around
    every step. Subclasses must override ``init_fn()``, ``step_()`` and
    ``plots()`` (and ``test()`` / ``train()`` when those are used); the versions
    defined here raise ``NotImplementedError``.
    """

    def __init__(self, network: Network, **kwargs) -> None:
        # language=rst
        """
        Initializes the pipeline.

        :param network: Arbitrary network object, will be managed by the
            ``BasePipeline`` class.

        Keyword arguments:

        :param int save_interval: Number of steps between saves of the network.
            ``None`` (default) disables saving.
        :param str save_dir: Path handed to ``network.save()``. Defaults to
            ``"network.pt"``.
        :param Dict[str, Any] plot_config: Dict containing the plot configuration.
            Includes length, type (``"color"`` or ``"line"``), and interval per plot
            type. This class reads ``"data_step"`` (monitors are attached unless
            it is missing or ``None``) and ``"data_length"`` (passed to each
            ``Monitor``; 100 when missing).
        :param int print_interval: Number of steps between progress printouts.
            ``None`` (default) disables printing.
        :param int test_interval: Number of steps between calls to ``test()``.
            ``None`` (default) disables testing.
        :param bool allow_gpu: Allows automatic transfer to the GPU.
        """
        self.network = network

        # Network saving handles caching of intermediate results.
        self.save_dir = kwargs.get("save_dir", "network.pt")
        self.save_interval = kwargs.get("save_interval", None)

        # Handles plotting of all layer spikes and voltages.
        # This constructs monitors at every level.
        self.plot_config = kwargs.get(
            "plot_config", {"data_step": True, "data_length": 100}
        )

        # A caller-supplied config may omit either key, so read both with
        # ``.get()`` instead of indexing (which raised ``KeyError``).
        if self.plot_config.get("data_step") is not None:
            data_length = self.plot_config.get("data_length", 100)
            layers = self.network.layers
            for name in layers:
                layer = layers[name]
                self.network.add_monitor(
                    Monitor(layer, "s", data_length), name=f"{name}_spikes"
                )
                # Only layers exposing a ``v`` attribute get a voltage monitor.
                if hasattr(layer, "v"):
                    self.network.add_monitor(
                        Monitor(layer, "v", data_length), name=f"{name}_voltages"
                    )

        self.print_interval = kwargs.get("print_interval", None)
        self.test_interval = kwargs.get("test_interval", None)
        self.step_count = 0

        # Resolve the device before running the subclass hook so that
        # ``init_fn()`` can already read ``self.device`` / ``self.allow_gpu``.
        self.allow_gpu = kwargs.get("allow_gpu", True)
        if torch.cuda.is_available() and self.allow_gpu:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.init_fn()

        # Reference point for the elapsed time reported by ``step()``.
        self.clock = time.time()

        self.network.to(self.device)

    def _interval_elapsed(self, interval) -> bool:
        # language=rst
        """
        Tell whether the current step count is a multiple of ``interval``.

        :param interval: Step interval, or ``None`` when the feature is disabled.
        :return: ``False`` if ``interval`` is ``None``; otherwise whether
            ``step_count`` is divisible by ``interval``.
        """
        return interval is not None and self.step_count % interval == 0

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Reset the pipeline: resets the network's state variables and sets the
        step counter back to zero.
        """
        self.network.reset_state_variables()
        self.step_count = 0

    def step(self, batch: Any, **kwargs) -> Any:
        # language=rst
        """
        Single step of any pipeline at a high level.

        Increments the step counter, moves the batch to the pipeline's device,
        runs ``step_()``, then prints, plots, saves and tests according to the
        configured intervals.

        :param batch: A batch of inputs to be handed to the ``step_()`` function.
                      Standard in subclasses of ``BasePipeline``.
        :return: The output from the subclass's ``step_()`` method, which could be
            anything. Passed to plotting to accommodate this.
        """
        self.step_count += 1

        batch = recursive_to(batch, self.device)
        step_out = self.step_(batch, **kwargs)

        if self._interval_elapsed(self.print_interval):
            print(
                f"Iteration: {self.step_count} (Time: {time.time() - self.clock:.4f})"
            )
            # Restart the timer so the next printout covers only its own window.
            self.clock = time.time()

        # Plotting is invoked on every step; the subclass decides what to draw.
        self.plots(batch, step_out)

        if self._interval_elapsed(self.save_interval):
            self.network.save(self.save_dir)

        if self._interval_elapsed(self.test_interval):
            self.test()

        return step_out

    def get_spike_data(self) -> Dict[str, torch.Tensor]:
        # language=rst
        """
        Get the spike data from all layers in the pipeline's network.

        Layers without a ``"<layer>_spikes"`` monitor (for example when monitoring
        was disabled through ``plot_config``) are left out instead of raising
        ``KeyError``.

        :return: A dictionary mapping layer name to the recorded ``"s"`` data of
            its spike monitor.
        """
        monitors = self.network.monitors
        return {
            name: monitors[f"{name}_spikes"].get("s")
            for name in self.network.layers
            if f"{name}_spikes" in monitors
        }

    def get_voltage_data(
        self,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        # language=rst
        """
        Get the voltage data and threshold value from all applicable layers in the
        pipeline's network.

        A layer contributes voltage data when it has a ``v`` attribute and a
        ``"<layer>_voltages"`` monitor exists; it contributes a threshold when it
        has a ``thresh`` attribute. The two conditions are independent.

        :return: Two dictionaries containing the voltage data and threshold values from
            the network.
        """
        layers = self.network.layers
        monitors = self.network.monitors
        voltage_record = {}
        threshold_value = {}
        for name in layers:
            layer = layers[name]
            monitor_key = f"{name}_voltages"
            # Skip layers whose voltage monitor was never created rather than
            # failing with ``KeyError``.
            if hasattr(layer, "v") and monitor_key in monitors:
                voltage_record[name] = monitors[monitor_key].get("v")
            if hasattr(layer, "thresh"):
                threshold_value[name] = layer.thresh

        return voltage_record, threshold_value

    def step_(self, batch: Any, **kwargs) -> Any:
        # language=rst
        """
        Perform a pass of the network given the input batch.

        :param batch: The current batch. This could be anything as long as the subclass
            agrees upon the format in some way.
        :return: Any output that is needed for recording purposes.
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("You need to provide a step_ method.")

    def train(self) -> None:
        # language=rst
        """
        A fully self-contained training loop.

        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("You need to provide a train method.")

    def test(self) -> None:
        # language=rst
        """
        A fully self-contained test function. Called from ``step()`` every
        ``test_interval`` steps when that interval is set.

        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("You need to provide a test method.")

    def init_fn(self) -> None:
        # language=rst
        """
        Placeholder function for subclass-specific actions that need to
        happen during the construction of the ``BasePipeline``.

        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("You need to provide an init_fn method.")

    def plots(self, batch: Any, step_out: Any) -> None:
        # language=rst
        """
        Create any plots and logs for a step given the input batch and step output.

        :param batch: The current batch. This could be anything as long as the subclass
            agrees upon the format in some way.
        :param step_out: The output from the ``step_()`` method.
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("You need to provide a plots method.")
原文地址:https://www.cnblogs.com/lucifer1997/p/14346152.html