Update KFTO test to add support for the disconnected environment
abhijeet-dhumal committed Nov 25, 2024
1 parent e4b9e14 commit c45a274
Showing 2 changed files with 66 additions and 2 deletions.
62 changes: 60 additions & 2 deletions tests/kfto/core/kfto_training_test.go
@@ -31,7 +31,7 @@ import (
)

func TestPyTorchJobWithCuda(t *testing.T) {
- runKFTOPyTorchJob(t, GetCudaTrainingImage(), "nvidia.com/gpu", 1)
+ runKFTOPyTorchJob(t, GetCudaTrainingImage(), "nvidia.com/gpu", 0)
}

func TestPyTorchJobWithROCm(t *testing.T) {
@@ -228,7 +228,7 @@ func createKFTOPyTorchJob(test Test, namespace string, config corev1.ConfigMap,
AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce},
Resources: corev1.VolumeResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceStorage: resource.MustParse("2000Gi"),
corev1.ResourceStorage: resource.MustParse("200Gi"),
},
},
VolumeMode: Ptr(corev1.PersistentVolumeFilesystem),
@@ -253,6 +253,64 @@
},
}

storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
storage_bucket_prefix, storage_bucket_prefix_exists := GetStorageBucketPrefix()

if storage_bucket_endpoint_exists {
if storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_prefix_exists {
test.T().Logf("Downloading dataset from storage bucket using provided S3-credentials")
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.InitContainers[1].Command = []string{
"/bin/sh",
"-c",
fmt.Sprintf(`pip install --target /tmp/.local datasets minio && \
HF_HOME=/tmp/.cache PYTHONPATH=/tmp/.local python -c "
import os
from minio import Minio
from datasets import Dataset, load_from_disk

# S3 bucket and file details
endpoint = '%s'
s3_bucket = '%s'
access_key = '%s'
secret_key = '%s'
s3_prefix = '%s'
local_dir = '/tmp/alpaca_dataset'

# Download dataset files from S3
os.makedirs(local_dir, exist_ok=True)

# Strip the scheme if it is included in the storage bucket endpoint URL
secure = True
if endpoint.startswith('https://'):
    endpoint = endpoint[len('https://'):]
elif endpoint.startswith('http://'):
    endpoint = endpoint[len('http://'):]
    secure = False

client = Minio(
    endpoint,
    access_key=access_key,
    secret_key=secret_key,
    cert_check=False,
    secure=secure,
)

objects = client.list_objects(s3_bucket, prefix=s3_prefix, recursive=True)
for obj in objects:
    local_file_path = os.path.join(local_dir, os.path.relpath(obj.object_name, s3_prefix))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
    client.fget_object(s3_bucket, obj.object_name, local_file_path)
    print('Downloaded: ' + obj.object_name + ' -> ' + local_file_path)

# Load the dataset from the downloaded directory and save a 100-example training subset
dataset = load_from_disk(local_dir)
train_subset = Dataset.from_dict(dataset['train'][:100])
train_subset.save_to_disk('/tmp/dataset')
"`, storage_bucket_endpoint, storage_bucket_name, storage_bucket_access_key_id, storage_bucket_secret_key, storage_bucket_prefix),
}
} else {
test.T().Errorf("'AWS_DEFAULT_ENDPOINT' environment variable exists, please provide 'AWS_ACCESS_KEY_ID' | 'AWS_SECRET_ACCESS_KEY' | 'AWS_STORAGE_BUCKET' | 'AWS_STORAGE_BUCKET_PREFIX'")
}
}

tuningJob, err := test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Create(test.Ctx(), tuningJob, metav1.CreateOptions{})
test.Expect(err).NotTo(HaveOccurred())
test.T().Logf("Created PytorchJob %s/%s successfully", tuningJob.Namespace, tuningJob.Name)
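The branch above is driven entirely by environment variables read in the test process. As a rough illustration (not part of the commit), a hypothetical helper could set the variables the test checks before the job is created; every value below is a placeholder for an S3-compatible mirror reachable from the disconnected cluster.

package core

import "testing"

// setDisconnectedEnv is a hypothetical helper sketching the environment variables
// that the disconnected-environment branch keys on; all values are placeholders.
func setDisconnectedEnv(t *testing.T) {
	t.Setenv("AWS_DEFAULT_ENDPOINT", "https://minio.example.internal:9000")
	t.Setenv("AWS_ACCESS_KEY_ID", "example-access-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "example-secret-key")
	t.Setenv("AWS_STORAGE_BUCKET", "example-bucket")
	t.Setenv("AWS_STORAGE_BUCKET_PREFIX", "alpaca_dataset")
}

In CI these variables would typically be exported in the environment before the test binary runs rather than set from Go.
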
6 changes: 6 additions & 0 deletions tests/kfto/core/support.go
@@ -19,6 +19,7 @@ package core
import (
"embed"
"fmt"
"os"
"time"

. "github.com/onsi/gomega"
@@ -89,6 +90,11 @@ func GetOrCreateTestNamespace(t Test) string {
return namespaceName
}

func GetStorageBucketPrefix() (string, bool) {
storage_bucket_prefix, exists := os.LookupEnv("AWS_STORAGE_BUCKET_PREFIX")
return storage_bucket_prefix, exists
}

func uploadToS3(test Test, namespace string, pvcName string, storedAssetsPath string) {
defaultEndpoint, found := GetStorageBucketDefaultEndpoint()
test.Expect(found).To(BeTrue(), "Storage bucket default endpoint needs to be specified for S3 upload")
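Only GetStorageBucketPrefix is added by this commit; the other helpers the training test calls (GetStorageBucketDefaultEndpoint, GetStorageBucketAccessKeyId, GetStorageBucketSecretKey, GetStorageBucketName) are assumed to exist elsewhere in the test support code. A sketch of their presumed shape, using the environment variable names listed in the test's error message:

package core

import "os"

// Presumed shape of the existing storage bucket helpers (not shown in this diff);
// the environment variable names come from the error message in the training test.
func GetStorageBucketDefaultEndpoint() (string, bool) {
	endpoint, exists := os.LookupEnv("AWS_DEFAULT_ENDPOINT")
	return endpoint, exists
}

func GetStorageBucketAccessKeyId() (string, bool) {
	accessKeyId, exists := os.LookupEnv("AWS_ACCESS_KEY_ID")
	return accessKeyId, exists
}

func GetStorageBucketSecretKey() (string, bool) {
	secretKey, exists := os.LookupEnv("AWS_SECRET_ACCESS_KEY")
	return secretKey, exists
}

func GetStorageBucketName() (string, bool) {
	bucket, exists := os.LookupEnv("AWS_STORAGE_BUCKET")
	return bucket, exists
}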
