Tuning Hyperparameters of a Distributed TensorFlow Model using Ray Train & Tune
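This example runs a hyperparameter sweep over a distributed TensorFlow MNIST training job. A `TensorflowTrainer` performs data-parallel training across a configurable number of workers, and a Tune `Tuner` samples learning rates and batch sizes for it via the `train_loop_config` key of the parameter space, keeping the trial with the best reported accuracy. On Python 3.12 and later the script exits early because TensorFlow is not supported there.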
```python
import argparse
import sys

import ray
from ray import tune
from ray.train import ScalingConfig
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner

if sys.version_info >= (3, 12):
    # Skip this test in Python 3.12+ because TensorFlow is not supported.
    exit(0)
else:
    from ray.train.examples.tf.tensorflow_mnist_example import train_func
    from ray.train.tensorflow import TensorflowTrainer


def tune_tensorflow_mnist(
    num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
):
    # Each Tune trial launches a distributed data-parallel training run
    # with `num_workers` TensorFlow workers.
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    # Wrap the trainer in a Tuner. Hyperparameters for the training loop
    # are sampled under the `train_loop_config` key of the param space.
    tuner = Tuner(
        trainer,
        tune_config=TuneConfig(num_samples=num_samples, metric="accuracy", mode="max"),
        param_space={
            "train_loop_config": {
                "lr": tune.loguniform(1e-4, 1e-1),
                "batch_size": tune.choice([32, 64, 128]),
                "epochs": 3,
            }
        },
    )
    best_accuracy = tuner.fit().get_best_result().metrics["accuracy"]
    print(f"Best accuracy: {best_accuracy}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--address", required=False, type=str, help="The address to use for Ray."
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Sets number of samples for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training."
    )
    args = parser.parse_args()

    if args.smoke_test:
        # Start a small local Ray cluster for a quick test run.
        num_gpus = args.num_workers if args.use_gpu else 0
        ray.init(num_cpus=8, num_gpus=num_gpus)
        tune_tensorflow_mnist(num_workers=2, num_samples=2, use_gpu=args.use_gpu)
    else:
        # Connect to an existing Ray cluster (or start one locally if
        # no address is given) and run the full sweep.
        ray.init(address=args.address)
        tune_tensorflow_mnist(
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            use_gpu=args.use_gpu,
        )
```
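The script above prints only the best accuracy. If you also want the hyperparameters that produced it, the `ResultGrid` returned by `tuner.fit()` exposes them through `get_best_result()`. A minimal sketch, assuming it replaces the last two lines of `tune_tensorflow_mnist` above:

```python
# Minimal sketch: inspect the best trial's sampled hyperparameters along
# with its reported metric. Assumes `tuner` was built as in the example.
results = tuner.fit()
best = results.get_best_result(metric="accuracy", mode="max")
print("Best hyperparameters:", best.config["train_loop_config"])
print("Best accuracy:", best.metrics["accuracy"])
```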