Step 2: Implement the PyTorch Lightning GithubRepoRunner Component

The PyTorch Lightning GithubRepoRunner component subclasses GithubRepoRunner and tailors the execution experience to PyTorch Lightning.

This component adds two features tailored for PyTorch Lightning users:

  • It dynamically injects a custom TensorboardServerLauncher callback into the PyTorch Lightning Trainer to start a TensorBoard server that can be exposed in the Lightning App UI (a conceptual sketch of this injection pattern follows the list).

  • Once the script has run, the on_after_run hook of the TracerPythonScript is invoked with the script globals, so we can collect anything we need. In particular, we reload the best model, convert it to TorchScript, and store its path in the state alongside the best metric score.
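
Before looking at the component itself, here is a small, self-contained sketch of the constructor-patching idea the first bullet relies on. It does not use the Lightning tracer API; the DemoTrainer class and the placeholder string are purely illustrative, but the mechanism (amending kwargs["callbacks"] before the real __init__ runs) is the same one the tracer pre_fn applies to the real Trainer in the code below.


# Conceptual, pure-Python illustration of constructor patching.
# ``DemoTrainer`` is illustrative and is not the Lightning Trainer.
class DemoTrainer:
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])


original_init = DemoTrainer.__init__


def patched_init(self, *args, **kwargs):
    # Plays the same role as the tracer's ``pre_fn``:
    # amend the kwargs, then call the real ``__init__``.
    kwargs.setdefault("callbacks", []).append("TensorboardServerLauncher placeholder")
    original_init(self, *args, **kwargs)


DemoTrainer.__init__ = patched_init

trainer = DemoTrainer()
assert trainer.callbacks == ["TensorboardServerLauncher placeholder"]
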

Let’s dive into how to develop the component with the following code:


from functools import partial
from subprocess import Popen

from lightning.app.storage import Path


class PyTorchLightningGithubRepoRunner(GithubRepoRunner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_model_path = None
        self.best_model_score = None

    def configure_tracer(self):
        from lightning.pytorch import Trainer
        from lightning.pytorch.callbacks import Callback

        tracer = super().configure_tracer()

        class TensorboardServerLauncher(Callback):
            def __init__(self, work):
                # The provided ``work`` is the current
                # ``PyTorchLightningGithubRepoRunner`` work.
                self.w = work

            def on_train_start(self, trainer, *_):
                # Add `host` and `port` for tensorboard to work in the cloud.
                cmd = f"tensorboard --logdir='{trainer.logger.log_dir}'"
                server_args = f"--host {self.w.host} --port {self.w.port}"
                Popen(cmd + " " + server_args, shell=True)

        def trainer_pre_fn(self, *args, work=None, **kwargs):
            # Intercept Trainer __init__ call
            # and inject a ``TensorboardServerLauncher`` component.
            kwargs["callbacks"].append(TensorboardServerLauncher(work))
            return {}, args, kwargs

        # Patch the ``__init__`` method of the Trainer
        # to inject our callback with a reference to this work.
        tracer.add_traced(
            Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self))
        return tracer

    def on_after_run(self, end_script_globals):
        import torch

        # 1. Once the script has finished executing,
        # we can collect its globals and access any objects.
        trainer = end_script_globals["cli"].trainer
        checkpoint_callback = trainer.checkpoint_callback
        lightning_module = trainer.lightning_module

        # 2. From the checkpoint_callback,
        # we retrieve the best model checkpoint.
        checkpoint = torch.load(checkpoint_callback.best_model_path)

        # 3. Load the best weights and torchscript the model.
        lightning_module.load_state_dict(checkpoint["state_dict"])
        lightning_module.to_torchscript(f"{self.name}.pt")

        # 4. Use lightning.app.storage.Path to create a reference to the
        # TorchScripted model. In the cloud with multiple machines,
        # simply passing this reference to another work
        # automatically triggers a file transfer.
        self.best_model_path = Path(f"{self.name}.pt")

        # 5. Keep track of the metrics.
        self.best_model_score = float(checkpoint_callback.best_model_score)
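
To see how this component might be used, here is a minimal usage sketch. It is not part of the original tutorial: the constructor arguments (id, github_repo, script_path, script_args, requirements) are assumed to match the GithubRepoRunner defined in Step 1, and the repository URL, script path, and script arguments are purely illustrative.


# Minimal usage sketch (assumed wiring, not from the original tutorial).
from lightning.app import LightningApp, LightningFlow


class RootFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.runner = PyTorchLightningGithubRepoRunner(
            id="my-run",  # assumed argument from the Step 1 GithubRepoRunner
            github_repo="https://github.com/Lightning-AI/lightning-quick-start",  # illustrative
            script_path="train_script.py",  # illustrative
            script_args=["--trainer.max_epochs=2"],  # illustrative
            requirements=[],  # assumed argument from the Step 1 GithubRepoRunner
        )

    def run(self):
        self.runner.run()
        if self.runner.best_model_score is not None:
            # Both attributes are populated by ``on_after_run`` above.
            print(self.runner.best_model_score, self.runner.best_model_path)


app = LightningApp(RootFlow())
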