|
|
@@ -5,8 +5,8 @@ import os |
|
|
|
import sys |
|
|
|
from contextlib import contextmanager |
|
|
|
from pathlib import Path |
|
|
|
import pkg_resources as pkg |
|
|
|
|
|
|
|
import pkg_resources as pkg |
|
|
|
import yaml |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
@@ -49,9 +49,11 @@ def check_wandb_dataset(data_file): |
|
|
|
if check_file(data_file) and data_file.endswith('.yaml'): |
|
|
|
with open(data_file, errors='ignore') as f: |
|
|
|
data_dict = yaml.safe_load(f) |
|
|
|
is_wandb_artifact = (data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX) or |
|
|
|
data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX)) |
|
|
|
if is_wandb_artifact: |
|
|
|
is_trainset_wandb_artifact = (isinstance(data_dict['train'], str) and |
|
|
|
data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX)) |
|
|
|
is_valset_wandb_artifact = (isinstance(data_dict['val'], str) and |
|
|
|
data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX)) |
|
|
|
if is_trainset_wandb_artifact or is_valset_wandb_artifact: |
|
|
|
return data_dict |
|
|
|
else: |
|
|
|
return check_dataset(data_file) |