Influence the future of Ray with our Ray Community Pulse survey. Eligible participants who complete it by Monday, January 27th, 2025 will receive exclusive swag.
#!/usr/bin/env python
"""Example: tuning a checkpointing function trainable with HyperBand.

Defines a toy training function that reports a synthetic
``episode_reward_mean`` metric and runs it through ``tune.Tuner`` with a
``HyperBandScheduler``.
"""

import argparse
import json
import os
import tempfile

import numpy as np

import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import HyperBandScheduler


def train_func(config):
    """Run the toy training loop, resuming from a checkpoint if one exists.

    For each timestep up to 100 the loop computes
    ``tanh(timestep / width) * height`` and reports it as
    ``episode_reward_mean``. Every third timestep the current timestep is
    also saved to ``checkpoint.json`` and attached to the report.

    Args:
        config: Mapping of hyperparameters. ``width`` and ``height`` are read
            with ``config.get`` and both default to 1 when absent.
    """
    step = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        # Resume one step after the last timestep that was checkpointed.
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
                step = json.load(f)["timestep"] + 1

    for timestep in range(step, 100):
        v = np.tanh(float(timestep) / config.get("width", 1))
        v *= config.get("height", 1)

        # Checkpoint the state of the training every 3 steps
        # Note that this is only required for certain schedulers
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            if timestep % 3 == 0:
                with open(
                    os.path.join(temp_checkpoint_dir, "checkpoint.json"), "w"
                ) as f:
                    json.dump({"timestep": timestep}, f)
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Here we use `episode_reward_mean`, but you can also report other
            # objectives such as loss or accuracy.
            # The report happens inside the `with` so the temporary checkpoint
            # directory still exists when it is handed over.
            train.report({"episode_reward_mean": v}, checkpoint=checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # parse_known_args ignores any extra command-line arguments.
    args, _ = parser.parse_known_args()

    ray.init(num_cpus=4 if args.smoke_test else None)

    # Hyperband early stopping, configured with `episode_reward_mean` as the
    # objective and `training_iteration` as the time unit,
    # which is automatically filled by Tune.
    hyperband = HyperBandScheduler(max_t=200)

    tuner = tune.Tuner(
        train_func,
        run_config=train.RunConfig(
            name="hyperband_test",
            stop={"training_iteration": 10 if args.smoke_test else 99999},
            failure_config=train.FailureConfig(
                fail_fast=True,
            ),
        ),
        tune_config=tune.TuneConfig(
            num_samples=20,
            metric="episode_reward_mean",
            mode="max",
            scheduler=hyperband,
        ),
        # Only `height` is searched; `width` falls back to its default in
        # train_func.
        param_space={"height": tune.uniform(0, 100)},
    )
    results = tuner.fit()
    print("Best hyperparameters found were: ", results.get_best_result().config)