Update general.py (#823)
Fixes #822 `init_seeds` from `torch_utils` import is being overwritten by function `init_seeds` in `general.py`
This commit is contained in:
parent
5a7d79fbe6
commit
455f7b8f76
|
|
@ -23,7 +23,8 @@ from scipy.cluster.vq import kmeans
|
||||||
from scipy.signal import butter, filtfilt
|
from scipy.signal import butter, filtfilt
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from utils.torch_utils import init_seeds, is_parallel
|
from utils.torch_utils import init_seeds as init_torch_seeds
|
||||||
|
from utils.torch_utils import is_parallel
|
||||||
|
|
||||||
# Set printoptions
|
# Set printoptions
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||||
|
|
@ -55,7 +56,7 @@ def set_logging(rank=-1):
|
||||||
def init_seeds(seed=0):
|
def init_seeds(seed=0):
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
init_seeds(seed=seed)
|
init_torch_seeds(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
def get_latest_run(search_dir='./runs'):
|
def get_latest_run(search_dir='./runs'):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue