Influence the future of Ray with our Ray Community Pulse survey. Complete it by Monday, January 27th, 2025; eligible participants will receive exclusive swag.
# Original Code here:# https://github.com/pytorch/examples/blob/master/mnist/main.pyfrom__future__importprint_functionimportargparseimportosimporttorchimporttorch.optimasoptimimportrayfromrayimporttrain,tunefromray.tune.examples.mnist_pytorchimport(ConvNet,get_data_loaders,test_func,train_func,)fromray.tune.schedulersimportASHAScheduler# Change these values if you want the training to run quicker or slower.EPOCH_SIZE=512TEST_SIZE=256# Training settingsparser=argparse.ArgumentParser(description="PyTorch MNIST Example")parser.add_argument("--use-gpu",action="store_true",default=False,help="enables CUDA training")parser.add_argument("--ray-address",type=str,help="The Redis address of the cluster.")parser.add_argument("--smoke-test",action="store_true",help="Finish quickly for testing")# Below comments are for documentation purposes only.# fmt: off# __trainable_example_begin__classTrainMNIST(tune.Trainable):defsetup(self,config):use_cuda=config.get("use_gpu")andtorch.cuda.is_available()self.device=torch.device("cuda"ifuse_cudaelse"cpu")self.train_loader,self.test_loader=get_data_loaders()self.model=ConvNet().to(self.device)self.optimizer=optim.SGD(self.model.parameters(),lr=config.get("lr",0.01),momentum=config.get("momentum",0.9))defstep(self):train_func(self.model,self.optimizer,self.train_loader,device=self.device)acc=test_func(self.model,self.test_loader,self.device)return{"mean_accuracy":acc}defsave_checkpoint(self,checkpoint_dir):checkpoint_path=os.path.join(checkpoint_dir,"model.pth")torch.save(self.model.state_dict(),checkpoint_path)defload_checkpoint(self,checkpoint_dir):checkpoint_path=os.path.join(checkpoint_dir,"model.pth")self.model.load_state_dict(torch.load(checkpoint_path))# __trainable_example_end__# fmt: 
onif__name__=="__main__":args=parser.parse_args()ray.init(address=args.ray_address,num_cpus=6ifargs.smoke_testelseNone)sched=ASHAScheduler()tuner=tune.Tuner(tune.with_resources(TrainMNIST,resources={"cpu":3,"gpu":int(args.use_gpu)}),run_config=train.RunConfig(stop={"mean_accuracy":0.95,"training_iteration":3ifargs.smoke_testelse20,},checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True,checkpoint_frequency=3),),tune_config=tune.TuneConfig(metric="mean_accuracy",mode="max",scheduler=sched,num_samples=1ifargs.smoke_testelse20,),param_space={"args":args,"lr":tune.uniform(0.001,0.1),"momentum":tune.uniform(0.1,0.9),},)results=tuner.fit()print("Best config is:",results.get_best_result().config)