Step 2: Implement the PyTorch Lightning GithubRepoRunner Component
The PyTorch Lightning GithubRepoRunner component subclasses GithubRepoRunner and tailors the execution experience to PyTorch Lightning.
This component adds two features for PyTorch Lightning users:
- It dynamically injects a custom TensorboardServerLauncher callback into the PyTorch Lightning Trainer to start a TensorBoard server, so it can be exposed in the Lightning App UI.
- Once the script has run, the on_after_run hook of the TracerPythonScript is invoked with the script globals, meaning we can collect anything we need. In particular, we reload the best model, torchscript it, and store its path in the state alongside the best metric score.
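The globals-collection idea behind the on_after_run hook can be illustrated with plain Python: executing a script with exec() leaves every object the script defined in its global namespace, which is roughly what the tracer hands to the hook. This is only a sketch of the mechanism, not the Lightning internals, and the script content below is made up for illustration.

```python
# Illustrative sketch: run a "script" with exec() and then read objects
# out of its global namespace, the same idea as receiving
# ``end_script_globals`` in ``on_after_run``.
script = (
    "best_model_score = 0.92\n"
    "best_model_path = 'checkpoints/best.ckpt'\n"
)

end_script_globals = {}
exec(script, end_script_globals)

# After execution, anything the script defined is reachable by name:
score = float(end_script_globals["best_model_score"])
path = end_script_globals["best_model_path"]
print(score, path)
```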
Let’s dive into how to develop the component with the following code:
from functools import partial
from subprocess import Popen

from lightning.app.storage import Path


class PyTorchLightningGithubRepoRunner(GithubRepoRunner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_model_path = None
        self.best_model_score = None

    def configure_tracer(self):
        from lightning.pytorch import Trainer
        from lightning.pytorch.callbacks import Callback

        tracer = super().configure_tracer()

        class TensorboardServerLauncher(Callback):
            def __init__(self, work):
                # The provided `work` is the current
                # ``PyTorchLightningGithubRepoRunner`` work.
                self.w = work

            def on_train_start(self, trainer, *_):
                # Add `host` and `port` for tensorboard to work in the cloud.
                cmd = f"tensorboard --logdir='{trainer.logger.log_dir}'"
                server_args = f"--host {self.w.host} --port {self.w.port}"
                Popen(cmd + " " + server_args, shell=True)

        def trainer_pre_fn(self, *args, work=None, **kwargs):
            # Intercept the Trainer ``__init__`` call
            # and inject a ``TensorboardServerLauncher`` callback.
            kwargs.setdefault("callbacks", []).append(TensorboardServerLauncher(work))
            return {}, args, kwargs

        # Patch the `__init__` method of the Trainer
        # to inject our callback with a reference to the work.
        tracer.add_traced(
            Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self))
        return tracer

    def on_after_run(self, end_script_globals):
        import torch

        # 1. Once the script has finished executing,
        # we can collect its globals and access any objects.
        trainer = end_script_globals["cli"].trainer
        checkpoint_callback = trainer.checkpoint_callback
        lightning_module = trainer.lightning_module

        # 2. From the checkpoint_callback,
        # access the best model weights.
        checkpoint = torch.load(checkpoint_callback.best_model_path)

        # 3. Load the best weights and torchscript the model.
        lightning_module.load_state_dict(checkpoint["state_dict"])
        lightning_module.to_torchscript(f"{self.name}.pt")

        # 4. Use ``lightning.app.storage.Path`` to create a reference to the
        # torchscripted model. In the cloud, with multiple machines,
        # simply passing this reference to another work
        # automatically triggers a file transfer.
        self.best_model_path = Path(f"{self.name}.pt")

        # 5. Keep track of the metric.
        self.best_model_score = float(checkpoint_callback.best_model_score)
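The `Trainer.__init__` interception above is done by the tracer's `add_traced` hook with a `pre_fn`. The same idea can be mimicked with plain Python by wrapping a class's `__init__`; the names below (`FakeTrainer`, `patch_init`, `injecting_pre_fn`) are hypothetical and only sketch the mechanism, not the tracer's actual implementation.

```python
from functools import partial


class FakeTrainer:
    # Hypothetical stand-in for ``lightning.pytorch.Trainer``.
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])


def injecting_pre_fn(self, *args, work=None, **kwargs):
    # Same shape as ``trainer_pre_fn``: tweak kwargs before __init__ runs.
    kwargs.setdefault("callbacks", []).append(work)
    return {}, args, kwargs


def patch_init(cls, pre_fn):
    # Hypothetical helper mimicking
    # ``tracer.add_traced(cls, "__init__", pre_fn=...)``.
    original = cls.__init__

    def wrapped(self, *args, **kwargs):
        _, args, kwargs = pre_fn(self, *args, **kwargs)
        original(self, *args, **kwargs)

    cls.__init__ = wrapped


patch_init(FakeTrainer, partial(injecting_pre_fn, work="tensorboard-launcher"))
trainer = FakeTrainer()
print(trainer.callbacks)  # the injected "work" shows up in the callbacks list
```

Because the `pre_fn` receives and returns `args` and `kwargs`, it can rewrite any constructor argument before the real `__init__` runs, which is how the component slips its callback into every `Trainer` the user's script creates.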