From 7c8ac5799fdee8733a52dafdad6a4d0df21f0399 Mon Sep 17 00:00:00 2001
From: "r.jaepel" <r.jaepel@fz-juelich.de>
Date: Mon, 4 Dec 2023 17:10:44 +0100
Subject: [PATCH] Refactor user input

---
 cadetrdm/initialize_repo.py | 10 +++++-----
 cadetrdm/io_utils.py        |  8 ++++++++
 cadetrdm/repositories.py    | 13 +++++++------
 3 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/cadetrdm/initialize_repo.py b/cadetrdm/initialize_repo.py
index ecfa7df..9747b14 100644
--- a/cadetrdm/initialize_repo.py
+++ b/cadetrdm/initialize_repo.py
@@ -12,7 +12,7 @@ except ImportError:
 import cadetrdm
 from cadetrdm.repositories import ProjectRepo, OutputRepo
 from cadetrdm.web_utils import ssh_url_to_http_url
-from cadetrdm.io_utils import write_lines_to_file, is_tool
+from cadetrdm.io_utils import write_lines_to_file, is_tool, wait_for_user
 
 
 def init_lfs(lfs_filetypes: list, path: str = None):
@@ -130,10 +130,10 @@ def initialize_git(folder="."):
 
     try:
         repo = git.Repo(".")
-        proceed = input(f'The target directory already contains a git repo.\n'
-                        f'Please back up or push all changes to the repo before continuing.'
-                        f'Proceed? Y/n \n')
-        if not (proceed.lower() == "y" or proceed == ""):
+        proceed = wait_for_user('The target directory already contains a git repo.\n'
+                                'Please back up or push all changes to the repo before continuing.\n'
+                                'Proceed?')
+        if not proceed:
             raise KeyboardInterrupt
     except git.exc.InvalidGitRepositoryError:
         os.system(f"git init")
diff --git a/cadetrdm/io_utils.py b/cadetrdm/io_utils.py
index 07cdd0d..a45d9a6 100644
--- a/cadetrdm/io_utils.py
+++ b/cadetrdm/io_utils.py
@@ -70,3 +70,11 @@ def delete_path(filename):
         shutil.rmtree(absolute_path, onerror=remove_readonly)
     else:
         os.remove(absolute_path)
+
+
+def wait_for_user(message):
+    proceed = input(message + " Y/n \n")
+    if proceed.lower() == "y" or proceed == "":
+        return True
+    else:
+        return False
diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py
index e305b53..7e3452e 100644
--- a/cadetrdm/repositories.py
+++ b/cadetrdm/repositories.py
@@ -14,7 +14,8 @@ from urllib.request import urlretrieve
 from tabulate import tabulate
 import pandas as pd
 
-from cadetrdm.io_utils import recursive_chmod, write_lines_to_file
+from cadetrdm.io_utils import recursive_chmod, write_lines_to_file, wait_for_user
+from cadetrdm.jupyter_functionality import Notebook
 from cadetrdm.version import version as cadetrdm_version
 
 try:
@@ -371,11 +372,11 @@ class BaseRepo:
         self.add(".")
 
     def reset_hard_to_head(self):
-        proceed = input(f'The output directory contains the following uncommitted changes:\n'
-                        f'{self.untracked_files + self.changed_files}\n'
-                        f' These will be lost if you continue\n'
-                        f'Proceed? Y/n \n')
-        if not (proceed.lower() == "y" or proceed == ""):
+        proceed = wait_for_user(f'The output directory contains the following uncommitted changes:\n'
+                                f'{self.untracked_files + self.changed_files}\n'
+                                f' These will be lost if you continue\n'
+                                f'Proceed?')
+        if not proceed:
             raise KeyboardInterrupt
         # reset all tracked files to previous commit, -q silences output
         self._git.reset("-q", "--hard", "HEAD")
-- 
GitLab