-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path run_sagemaker_job.py
More file actions
44 lines (35 loc) · 937 Bytes
/
run_sagemaker_job.py
File metadata and controls
44 lines (35 loc) · 937 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sagemaker
from sagemaker.pytorch import PyTorch
import os
# Configuration
sagemaker_session = sagemaker.Session()
# Replace with your S3 bucket
bucket = os.environ.get("S3_BUCKET")
prefix = "mnist-mlops"
# SageMaker IAM role
role = os.environ.get("SAGEMAKER_ROLE_ARN")
# Local source directory
source_dir = "./src"
entry_point = "train.py"
# Hyperparameters
hyperparameters = {
"batch_size": 100,
"num_epochs": 5,
"lr": 0.01
}
# Create PyTorch Estimator
estimator = PyTorch(
entry_point=entry_point,
source_dir=source_dir,
role=role,
framework_version="2.2.0",
py_version="py310",
instance_count=1,
instance_type="ml.m5.large",
hyperparameters=hyperparameters,
output_path=f"s3://{bucket}/{prefix}/model",
)
# Launch Training Job
print("Launching SageMaker training job...")
estimator.fit()
print(f"Training complete! Model saved to s3://{bucket}/{prefix}/model")