
Source code for ignite.handlers.checkpoint

import collections.abc as collections
import numbers
import os
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, namedtuple
from tempfile import _TemporaryFileWrapper  # type: ignore
from typing import Callable, Mapping, Optional, Union

import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.base import Serializable
from ignite.engine import Engine, Events

__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]


class BaseSaveHandler(metaclass=ABCMeta):
    """Base class for save handlers.

    Methods to override:

    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.__call__`
    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.remove`

    Note:
        In a derived class, please make sure that in distributed configuration overridden methods are called by a
        single process. Distributed configuration on XLA devices should be treated slightly differently: to save a
        checkpoint with `xm.save() <https://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.save>`_,
        all processes should call into the function; otherwise, the application gets stuck.

    """

    @abstractmethod
    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        """Method to save `checkpoint` with `filename`. Additionally, a metadata dictionary is provided.

        Metadata contains:

        - `basename`: file prefix (if provided) with checkpoint name, e.g. `epoch_checkpoint`.
        - `score_name`: score name if provided, e.g. `val_acc`.
        - `priority`: checkpoint priority value (higher is better), e.g. `12` or `0.6554435`

        Args:
            checkpoint (Mapping): checkpoint dictionary to save.
            filename (str): filename associated with checkpoint.
            metadata (Mapping, optional): metadata on checkpoint to save.
        """
        pass

    @abstractmethod
    def remove(self, filename: str) -> None:
        """Method to remove saved checkpoint.

        Args:
            filename (str): filename associated with checkpoint.
        """
        pass
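
# A minimal sketch of a custom save handler built on this base class. ``InMemorySaver`` and
# its ``storage`` attribute are illustrative names, not part of this module; only ``__call__``
# and ``remove`` are required by ``BaseSaveHandler``:
#
#     class InMemorySaver(BaseSaveHandler):
#         def __init__(self):
#             self.storage = {}
#
#         def __call__(self, checkpoint, filename, metadata=None):
#             # keep the checkpoint dictionary in memory under its filename
#             self.storage[filename] = checkpoint
#
#         def remove(self, filename):
#             del self.storage[filename]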
class Checkpoint(Serializable):
    """Checkpoint handler can be used to periodically save and load objects which have attribute
    ``state_dict``/``load_state_dict``. This class can use specific save handlers to store on the disk,
    a cloud storage, etc. The Checkpoint handler (if used with :class:`~ignite.handlers.DiskSaver`) also
    automatically moves data on TPU to CPU before writing the checkpoint.

    Args:
        to_save (Mapping): Dictionary with the objects to save. Objects should have implemented ``state_dict``
            and ``load_state_dict`` methods. If it contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, their internal wrapped model is automatically saved (to avoid the additional key
            ``module.`` in the state dictionary).
        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
            use to save the engine and other provided objects. The function receives two objects: the checkpoint as a
            dictionary and the filename. If ``save_handler`` is a callable class, it can inherit from
            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method to keep
            a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on a disk,
            ``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
        filename_prefix (str, optional): Prefix for the file name to which objects will be saved. See Note for
            details.
        score_function (callable, optional): If not None, it should be a function taking a single argument, an
            :class:`~ignite.engine.engine.Engine` object, and returning a score (`float`). Objects with the highest
            scores will be retained.
        score_name (str, optional): If ``score_function`` is not None, it is possible to store its value using
            ``score_name``. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set
            to `None`, all objects are kept.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is ``(engine, event_name)``. Output of the function should be an integer.
            Default is None, global_step based on attached engine. If provided, uses function output as global_step.
            To setup global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by ``torch.save`` are already compressed.
        filename_pattern (str, optional): If ``filename_pattern`` is provided, this pattern will be used to render
            checkpoint filenames. If the pattern is not defined, the default pattern is used. See Note for details.
        include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
            there must not be another object in ``to_save`` with key ``checkpointer``.

    .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
        torch.nn.parallel.DistributedDataParallel.html
    .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

    Note:
        This class stores a single file as a dictionary of provided objects to save.
        The filename is defined by ``filename_pattern`` and by default has the following
        structure: ``{filename_prefix}_{name}_{suffix}.{ext}`` where

        - ``filename_prefix`` is the argument passed to the constructor,
        - ``name`` is the key in ``to_save`` if a single object is to be stored, otherwise ``name`` is "checkpoint",
        - ``suffix`` is composed as ``{global_step}_{score_name}={score}``.

        +----------------+------------+-----------------------+----------------------------------------+
        | score_function | score_name | global_step_transform | suffix                                 |
        +================+============+=======================+========================================+
        |      None      |    None    |         None          | ``{engine.state.iteration}``           |
        +----------------+------------+-----------------------+----------------------------------------+
        |       X        |    None    |         None          | ``{score}``                            |
        +----------------+------------+-----------------------+----------------------------------------+
        |       X        |    None    |          X            | ``{global_step}_{score}``              |
        +----------------+------------+-----------------------+----------------------------------------+
        |       X        |     X      |          X            | ``{global_step}_{score_name}={score}`` |
        +----------------+------------+-----------------------+----------------------------------------+
        |      None      |    None    |          X            | ``{global_step}``                      |
        +----------------+------------+-----------------------+----------------------------------------+
        |       X        |     X      |         None          | ``{score_name}={score}``               |
        +----------------+------------+-----------------------+----------------------------------------+

        Above, ``global_step`` is defined by the output of ``global_step_transform`` and ``score`` is defined by the
        output of ``score_function``.

        By default, none of ``score_function``, ``score_name`` and ``global_step_transform`` is defined, so the
        suffix is set from the attached engine's current iteration. The filename will then be
        ``{filename_prefix}_{name}_{engine.state.iteration}.{ext}``.

        For example, with ``score_name="neg_val_loss"`` and a ``score_function`` that returns ``-loss`` (as objects
        with the highest scores are retained), the saved filename will be
        ``{filename_prefix}_{name}_neg_val_loss=-0.1234.pt``.

    Note:
        If ``filename_pattern`` is given, it will be used to render the filenames. ``filename_pattern`` is a string
        that can contain ``{filename_prefix}``, ``{name}``, ``{score}``, ``{score_name}`` and ``{global_step}`` as
        templates.

        For example, let ``filename_pattern="{global_step}-{name}-{score}.pt"``, then the saved filename will be
        ``30000-checkpoint-94.pt``.

        **Warning:** Please keep in mind that if the filename collides with one already used by a saved checkpoint,
        the new checkpoint will not be stored. This means that a filename like ``checkpoint.pt`` will be saved only
        once and will not be overwritten by newer checkpoints.

    Note:
        To get the last stored filename, the handler exposes the attribute ``last_checkpoint``:

        .. code-block:: python

            handler = Checkpoint(...)
            ...
            print(handler.last_checkpoint)
            > checkpoint_12345.pt

    Note:
        This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
        process only. This class automatically supports distributed configuration and, if used with
        :class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.

    .. warning::

        When running on XLA devices, it should be run in all processes, otherwise the application can get stuck on
        saving the checkpoint.

        .. code-block:: python

            # Wrong:
            # if idist.get_rank() == 0:
            #     handler = Checkpoint(...)
            #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

            # Correct:
            handler = Checkpoint(...)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    Examples:

        Attach the handler to make checkpoints during training:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver

            trainer = ...
            model = ...
            optimizer = ...
            lr_scheduler = ...

            to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

            trainer.run(data_loader, max_epochs=6)
            > ["checkpoint_7000.pt", "checkpoint_8000.pt", ]

        Attach the handler to an evaluator to save the best model during the training according to a computed
        validation metric:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            trainer = ...
            evaluator = ...
            # Setup Accuracy metric computation on evaluator
            # Run evaluation on epoch completed event
            # ...

            def score_function(engine):
                return engine.state.metrics['accuracy']

            to_save = {'model': model}
            handler = Checkpoint(
                to_save, DiskSaver('/tmp/models', create_dir=True),
                n_saved=2, filename_prefix='best',
                score_function=score_function, score_name="val_acc",
                global_step_transform=global_step_from_engine(trainer)
            )

            evaluator.add_event_handler(Events.COMPLETED, handler)

            trainer.run(data_loader, max_epochs=10)
            > ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
    """

    Item = namedtuple("Item", ["priority", "filename"])
    _state_dict_all_req_keys = ("saved",)

    def __init__(
        self,
        to_save: Optional[Mapping],
        save_handler: Union[Callable, BaseSaveHandler],
        filename_prefix: str = "",
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Optional[int] = 1,
        global_step_transform: Optional[Callable] = None,
        archived: bool = False,
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
    ):

        if to_save is not None:  # for compatibility with ModelCheckpoint
            if not isinstance(to_save, collections.Mapping):
                raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))

            if len(to_save) < 1:
                raise ValueError("No objects to checkpoint.")

            self._check_objects(to_save, "state_dict")

            if include_self:
                if not isinstance(to_save, collections.MutableMapping):
                    raise TypeError(
                        "If `include_self` is True, then `to_save` must be mutable, but given {}.".format(
                            type(to_save)
                        )
                    )

                if "checkpointer" in to_save:
                    raise ValueError("Cannot have key 'checkpointer' if `include_self` is True: {}".format(to_save))

        if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
            raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")

        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` should be also provided.")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError(
                "global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
            )

        if archived:
            warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

        self.to_save = to_save
        self.filename_prefix = filename_prefix
        self.save_handler = save_handler
        self.score_function = score_function
        self.score_name = score_name
        self.n_saved = n_saved
        self.ext = "pt"
        self.global_step_transform = global_step_transform
        self.filename_pattern = filename_pattern
        self._saved = []  # type: list
        self.include_self = include_self

    @property
    def last_checkpoint(self) -> Optional[str]:
        if len(self._saved) < 1:
            return None
        return self._saved[-1].filename

    def _check_lt_n_saved(self, or_equal=False):
        if self.n_saved is None:
            return True
        return len(self._saved) < self.n_saved + int(or_equal)

    def __call__(self, engine: Engine) -> None:
        global_step = None
        if self.global_step_transform is not None:
            global_step = self.global_step_transform(engine, engine.last_event_name)

        if self.score_function is not None:
            priority = self.score_function(engine)
            if not isinstance(priority, numbers.Number):
                raise ValueError("Output of score_function should be a number")
        else:
            if global_step is None:
                global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
            priority = global_step

        if self._check_lt_n_saved() or self._saved[0].priority < priority:

            priority_str = (
                "{}".format(priority) if isinstance(priority, numbers.Integral) else "{:.4f}".format(priority)
            )

            checkpoint = self._setup_checkpoint()

            name = "checkpoint"
            if len(checkpoint) == 1:
                for k in checkpoint:
                    name = k
                checkpoint = checkpoint[name]

            if self.filename_pattern is None:
                filename_pattern = self.setup_filename_pattern(
                    with_prefix=len(self.filename_prefix) > 0,
                    with_score=self.score_function is not None,
                    with_score_name=self.score_name is not None,
                    with_global_step=global_step is not None,
                )
            else:
                filename_pattern = self.filename_pattern

            filename_dict = {
                "filename_prefix": self.filename_prefix,
                "ext": self.ext,
                "name": name,
                "score_name": self.score_name,
                "score": priority_str if (self.score_function is not None) else None,
                "global_step": global_step,
            }
            filename = filename_pattern.format(**filename_dict)

            if any(item.filename == filename for item in self._saved):
                return

            metadata = {
                "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name),
                "score_name": self.score_name,
                "priority": priority,
            }

            if not self._check_lt_n_saved():
                item = self._saved.pop(0)
                if isinstance(self.save_handler, BaseSaveHandler):
                    self.save_handler.remove(item.filename)

            self._saved.append(Checkpoint.Item(priority, filename))
            self._saved.sort(key=lambda item: item[0])

            if self.include_self:
                # Now that we've updated _saved, we can add our own state_dict.
                checkpoint["checkpointer"] = self.state_dict()

            try:
                self.save_handler(checkpoint, filename, metadata)
            except TypeError:
                self.save_handler(checkpoint, filename)

    def _setup_checkpoint(self) -> dict:
        checkpoint = {}
        if self.to_save is not None:
            for k, obj in self.to_save.items():
                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                    obj = obj.module
                checkpoint[k] = obj.state_dict()
        return checkpoint

    @staticmethod
    def setup_filename_pattern(
        with_prefix: bool = True, with_score: bool = True, with_score_name: bool = True, with_global_step: bool = True
    ) -> str:
        """Helper method to get the default filename pattern for a checkpoint.

        Args:
            with_prefix (bool): If True, the ``filename_prefix`` is added to the filename pattern:
                ``{filename_prefix}_{name}...``. Default, True.
            with_score (bool): If True, ``score`` is added to the filename pattern: ``..._{score}.{ext}``.
                Default, True. At least one of ``with_score`` and ``with_global_step`` should be True.
            with_score_name (bool): If True, ``score_name`` is added to the filename pattern:
                ``..._{score_name}={score}.{ext}``. If activated, argument ``with_score`` should be
                also True, otherwise an error is raised. Default, True.
            with_global_step (bool): If True, ``{global_step}`` is added to the
                filename pattern: ``...{name}_{global_step}...``.
                At least one of ``with_score`` and ``with_global_step`` should be True.

        Example:

            .. code-block:: python

                from ignite.handlers import Checkpoint

                filename_pattern = Checkpoint.setup_filename_pattern()

                print(filename_pattern)
                > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
        """
        filename_pattern = "{name}"

        if not (with_global_step or with_score):
            raise ValueError("At least one of with_score and with_global_step should be True.")

        if with_global_step:
            filename_pattern += "_{global_step}"

        if with_score_name and with_score:
            filename_pattern += "_{score_name}={score}"
        elif with_score:
            filename_pattern += "_{score}"
        elif with_score_name:
            raise ValueError("If with_score_name is True, with_score should be also True")

        if with_prefix:
            filename_pattern = "{filename_prefix}_" + filename_pattern

        filename_pattern += ".{ext}"
        return filename_pattern

    @staticmethod
    def _check_objects(objs: Mapping, attr: str) -> None:
        for k, obj in objs.items():
            if not hasattr(obj, attr):
                raise TypeError("Object {} should have `{}` method".format(type(obj), attr))
    @staticmethod
    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None:
        """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from
        ``checkpoint``.

        Examples:

            .. code-block:: python

                import torch
                from ignite.engine import Engine, Events
                from ignite.handlers import ModelCheckpoint, Checkpoint

                trainer = Engine(lambda engine, batch: None)
                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
                model = torch.nn.Linear(3, 3)
                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
                to_save = {"weights": model, "optimizer": optimizer}

                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
                trainer.run(torch.randn(10, 1), 5)

                to_load = to_save
                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
                checkpoint = torch.load(checkpoint_fp)
                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        Note:
            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, the method ``load_state_dict`` will be applied to their internal wrapped model
            (``obj.module``).

        Args:
            to_load (Mapping): a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
            checkpoint (Mapping): a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
                directly the corresponding state_dict.
            **kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                the user to load part of the pretrained model (useful, for example, in Transfer Learning).

        .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
            torch.nn.parallel.DistributedDataParallel.html
        .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
        """
        Checkpoint._check_objects(to_load, "load_state_dict")
        if not isinstance(checkpoint, collections.Mapping):
            raise TypeError("Argument checkpoint should be a dictionary, but given {}".format(type(checkpoint)))

        if len(kwargs) > 1 or any(k for k in kwargs.keys() if k not in ["strict"]):
            warnings.warn("kwargs contains keys other than strict and these will be ignored")

        is_state_dict_strict = kwargs.get("strict", True)
        if len(to_load) == 1:
            # single object and checkpoint is directly a state_dict
            key, obj = list(to_load.items())[0]
            if key not in checkpoint:
                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                    obj = obj.module
                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
                return

        # multiple objects to load
        for k, obj in to_load.items():
            if k not in checkpoint:
                raise ValueError("Object labeled by '{}' from `to_load` is not found in the checkpoint".format(k))
            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                obj = obj.module
            if isinstance(obj, torch.nn.Module):
                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
            else:
                obj.load_state_dict(checkpoint[k])
    def state_dict(self) -> OrderedDict:
        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

    def load_state_dict(self, state_dict: Mapping) -> None:
        super().load_state_dict(state_dict)
        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]
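
# A minimal sketch of saving with ``Checkpoint`` and resuming later. ``trainer``, ``model``,
# ``optimizer`` and the directory ``/tmp/ckpts`` are illustrative placeholders:
#
#     to_save = {"model": model, "optimizer": optimizer, "trainer": trainer}
#     handler = Checkpoint(to_save, DiskSaver("/tmp/ckpts", create_dir=True), n_saved=2)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
#
#     # Later, to resume: ``last_checkpoint`` holds only the filename, DiskSaver holds the directory.
#     ckpt = torch.load(os.path.join("/tmp/ckpts", handler.last_checkpoint))
#     Checkpoint.load_objects(to_load=to_save, checkpoint=ckpt)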
class DiskSaver(BaseSaveHandler):
    """Handler that saves the input checkpoint on a disk.

    Args:
        dirname (str): Directory path where the checkpoint will be saved.
        atomic (bool, optional): if True, checkpoint is serialized to a temporary file, and then
            moved to the final destination, so that files are guaranteed to not be damaged
            (for example if an exception occurs during saving).
        create_dir (bool, optional): if True, will create directory ``dirname`` if it does not exist.
        require_empty (bool, optional): If True, will raise an exception if there are any files in the
            directory ``dirname``.
        **kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.
    """

    def __init__(
        self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs
    ):
        self.dirname = os.path.expanduser(dirname)
        self._atomic = atomic
        self._check_and_setup(dirname, create_dir, require_empty)
        self.kwargs = kwargs

    @staticmethod
    @idist.one_rank_only()
    def _check_and_setup(dirname, create_dir, require_empty):
        if create_dir:
            if not os.path.exists(dirname):
                os.makedirs(dirname)
        # Ensure that dirname exists
        if not os.path.exists(dirname):
            raise ValueError("Directory path '{}' is not found".format(dirname))

        if require_empty:
            matched = [fname for fname in os.listdir(dirname) if fname.endswith(".pt")]
            if len(matched) > 0:
                raise ValueError(
                    "Files {} with extension '.pt' are already present "
                    "in the directory {}. If you want to use this "
                    "directory anyway, pass `require_empty=False`."
                    "".format(matched, dirname)
                )

    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        path = os.path.join(self.dirname, filename)

        if idist.has_xla_support:
            self._save_xla(checkpoint, path)
        else:
            self._save_native(checkpoint, path)

    @idist.one_rank_only()
    def _save_native(self, checkpoint: Mapping, path: str):
        self._save_func(checkpoint, path, torch.save)

    def _save_xla(self, checkpoint: Mapping, path: str):
        import torch_xla.core.xla_model as xm  # type: ignore

        # all tpu procs should enter here as internally performs sync across device
        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())

    def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0):
        if not self._atomic:
            func(checkpoint, path, **self.kwargs)
        else:
            tmp_file = None
            tmp_name = ""
            tmp = None  # type: _TemporaryFileWrapper
            if rank == 0:
                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                tmp_file = tmp.file
                tmp_name = tmp.name
            try:
                func(checkpoint, tmp_file, **self.kwargs)
            except BaseException:
                if tmp is not None:
                    tmp.close()
                    os.remove(tmp_name)
                raise
            else:
                if tmp is not None:
                    tmp.close()
                    os.rename(tmp.name, path)

    @idist.one_rank_only()
    def remove(self, filename: str) -> None:
        path = os.path.join(self.dirname, filename)
        os.remove(path)
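
# A minimal sketch of using ``DiskSaver`` on its own; the directory, the saved dictionary and
# the filename below are illustrative:
#
#     saver = DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False)
#     saver({"model": model.state_dict()}, "manual_checkpoint.pt")
#     ...
#     saver.remove("manual_checkpoint.pt")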
class ModelCheckpoint(Checkpoint):
    """ModelCheckpoint handler can be used to periodically save objects to disk only. If checkpoints need to be
    stored to another storage type, please consider :class:`~ignite.handlers.checkpoint.Checkpoint`.

    This handler expects two arguments:

    - an :class:`~ignite.engine.engine.Engine` object
    - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Examples for further details.

    .. warning::

        Behaviour of this class has been changed since v0.3.0.

        Argument ``save_as_state_dict`` is deprecated and should not be used. It is considered as True.

        Argument ``save_interval`` is deprecated and should not be used. Please use events filtering instead, e.g.
        :attr:`~ignite.engine.events.Events.ITERATION_STARTED(every=1000)`

        There is no longer an internal counter used to indicate the number of save actions. The user could previously
        see its value ``step_number`` in the filename, e.g. `{filename_prefix}_{name}_{step_number}.pt`. Now,
        ``step_number`` is replaced by the current engine's epoch if ``score_function`` is specified, and by the
        current iteration otherwise.

        A single `pt` file is created instead of multiple files.

    Args:
        dirname (str): Directory path where objects will be saved.
        filename_prefix (str): Prefix for the file names to which objects will be saved. See Notes of
            :class:`~ignite.handlers.Checkpoint` for more details.
        score_function (callable, optional): if not None, it should be a function taking a single argument, an
            :class:`~ignite.engine.engine.Engine` object, and returning a score (`float`). Objects with the highest
            scores will be retained.
        score_name (str, optional): if ``score_function`` is not None, it is possible to store its value using
            `score_name`. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set
            to `None`, all objects are kept.
        atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to the final
            destination, so that files are guaranteed to not be damaged (for example if an exception
            occurs during saving).
        require_empty (bool, optional): If True, will raise an exception if there are any files starting with
            ``filename_prefix`` in the directory ``dirname``.
        create_dir (bool, optional): If True, will create directory ``dirname`` if it does not exist.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is `(engine, event_name)`. Output of the function should be an integer.
            Default is None, global_step based on attached engine. If provided, uses function output as global_step.
            To setup global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by `torch.save` are already compressed.
        include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
            there must not be another object in ``to_save`` with key ``checkpointer``.
        **kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.
    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda engine, batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
        >>> handler.last_checkpoint
        '/tmp/models/myprefix_mymodel_6.pt'
    """

    def __init__(
        self,
        dirname: str,
        filename_prefix: str,
        save_interval: Optional[Callable] = None,
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Union[int, None] = 1,
        atomic: bool = True,
        require_empty: bool = True,
        create_dir: bool = True,
        save_as_state_dict: bool = True,
        global_step_transform: Optional[Callable] = None,
        archived: bool = False,
        include_self: bool = False,
        **kwargs
    ):

        if not save_as_state_dict:
            raise ValueError(
                "Argument save_as_state_dict is deprecated and should be True."
                "This argument will be removed in 0.5.0."
            )
        if save_interval is not None:
            msg = (
                "Argument save_interval is deprecated and should be None. This argument will be removed in 0.5.0."
                "Please, use events filtering instead, e.g. Events.ITERATION_STARTED(every=1000)"
            )
            if save_interval == 1:
                # Do not break for old version who used `save_interval=1`
                warnings.warn(msg)
            else:
                # No choice
                raise ValueError(msg)

        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)

        super(ModelCheckpoint, self).__init__(
            to_save=None,
            save_handler=disk_saver,
            filename_prefix=filename_prefix,
            score_function=score_function,
            score_name=score_name,
            n_saved=n_saved,
            global_step_transform=global_step_transform,
            archived=archived,
            include_self=include_self,
        )

    @property
    def last_checkpoint(self) -> Union[str, None]:
        if len(self._saved) < 1:
            return None

        if not isinstance(self.save_handler, DiskSaver):
            raise RuntimeError(
                "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
            )

        return os.path.join(self.save_handler.dirname, self._saved[-1].filename)

    def __call__(self, engine: Engine, to_save: Mapping) -> None:  # type: ignore

        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._check_objects(to_save, "state_dict")
        self.to_save = to_save
        super(ModelCheckpoint, self).__call__(engine)
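
# A minimal sketch of keeping the two best models by validation accuracy with ``ModelCheckpoint``;
# ``trainer``, ``evaluator`` and the ``accuracy`` metric are illustrative placeholders
# (``global_step_from_engine`` comes from ``ignite.handlers``):
#
#     from ignite.handlers import global_step_from_engine
#
#     def score_function(engine):
#         return engine.state.metrics["accuracy"]
#
#     best_model_handler = ModelCheckpoint(
#         "/tmp/models", "best", n_saved=2, create_dir=True, require_empty=False,
#         score_function=score_function, score_name="val_acc",
#         global_step_transform=global_step_from_engine(trainer),
#     )
#     evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {"model": model})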