diff --git a/atlas_server/src/db.py b/atlas_server/src/db.py index 3233fa65895faf626d281d1afe008c6dfb11b330..3c916f26ad141d1508a55eea8e9305ba84df55fa 100644 --- a/atlas_server/src/db.py +++ b/atlas_server/src/db.py @@ -63,9 +63,9 @@ class Database(object): if self.client: raise RuntimeError("Cannot open connection: Connection already open.") self.client = MongoClient(DATABASE_HOST, - username=DATABASE_USERNAME, - password=DATABASE_PASSWORD, - port=DATABASE_PORT) + username=DATABASE_USERNAME, + password=DATABASE_PASSWORD, + port=DATABASE_PORT) return self def close(self): @@ -88,6 +88,10 @@ class Database(object): def projects(self): return self.database.projects + @property + def project_ids(self): + return self.database.project_ids + def _update_project_fields(self, project_id, key_value_dict, operation="set"): """ Update a single field @@ -98,10 +102,10 @@ class Database(object): """ forbidden_fields = ("modified", "project_id", "created") operations = { - "set": "$set", - "add": "$push", - "delete": "$pull", - } + "set": "$set", + "add": "$push", + "delete": "$pull", + } if any(field in key_value_dict for field in forbidden_fields): message = f"Following fields may not be updated: {', '.join(field for field in key_value_dict if field in forbidden_fields)}" @@ -109,10 +113,10 @@ class Database(object): mongo_operation = operations[operation] update_res = self.projects.update_one({"project_id": project_id}, - { - mongo_operation: key_value_dict - }) - # Raise error if no projects were affected + { + mongo_operation: key_value_dict + }) + # Raise error if no projects were affected # This means the project id does not exist if operation == "delete": affected_count = update_res.deleted_count @@ -132,17 +136,17 @@ class Database(object): # Update modified data of project modified_date = _get_timestamp() self.projects.update_one({"project_id": project_id}, - { - "$set": - { - "modified": str(modified_date), - }, - }) + { + "$set": + { + "modified": str(modified_date), + }, + }) return modified_date def insert_project(self, project): # Find current maximum project id - max_existing = self.projects.find_one(sort=[("project_id", DESCENDING)]) + max_existing = self.project_ids.find_one(sort=[("project_id", DESCENDING)]) if max_existing: new_id = max_existing["project_id"] + 1 else: @@ -160,16 +164,18 @@ class Database(object): document = _project_to_mongodb_item(project=project) # Insert into database self.projects.insert_one(document) + # Insert project id + self.project_ids.insert_one({"project_id": new_id, "next_task_id": 1, }) return document def update_project_values(self, project_id, key_value_dict): self._update_project_fields(project_id=project_id, - key_value_dict=key_value_dict, - operation="set") + key_value_dict=key_value_dict, + operation="set") - def get_project_by_id(self, project_id): - """ + def get_project_by_id(self, project_id): + """ Get a project with a given id. Args: @@ -203,9 +209,9 @@ class Database(object): """ forbidden_fields = ("modified", "task_id", "created") operations = { - "set": "$set", - "add": "$push", - } + "set": "$set", + "add": "$push", + } if any(field in key_value_dict for field in forbidden_fields): message = f"Following fields may not be updated: {', '.join(field for field in key_value_dict if field in forbidden_fields)}" @@ -239,12 +245,12 @@ class Database(object): # Update modified data of project modified_date = _get_timestamp() self.projects.update_one({"project_id": project_id, "tasks.task_id": task_id, }, - { - "$set": - { - "task.$.modified": str(modified_date), - }, - }) + { + "$set": + { + "task.$.modified": str(modified_date), + }, + }) return modified_date # noinspection PyUnresolvedReferences @@ -253,10 +259,17 @@ class Database(object): project = self.get_project_by_id(project_id=project_id) if project is None: raise ProjectNotFoundError(project_id) - if not project.tasks: - new_task_id = 1 + project_ids = self.project_ids.find_one({"project_id": project_id}) + if not project_ids: + # Fallback for projects which have no entry in project_ids + if not project.tasks: + new_task_id = 1 + else: + new_task_id = max(task.task_id for task in project.tasks) + 1 + project_ids.insert_one({"project_id": project_id, "next_task_id": new_task_id}) else: - new_task_id = max(task.task_id for task in project.tasks) + 1 + new_task_id = project_ids.next_task_id + task.task_id = new_task_id # Add created and modified info to task task.created = _get_timestamp() @@ -265,53 +278,55 @@ class Database(object): document = _task_schema.dump(task) # Insert task into the projects task list self._update_project_fields(project_id=project_id, - key_value_dict={ - "tasks": document - }, - operation="add") + key_value_dict={ + "tasks": document + }, + operation="add") + # Update next task id + self.project_ids.update_one({"project_id": project_id}, {"$set": {"next_task_id": new_task_id + 1}}) return document def get_task_by_id(self, project_id, task_id): res = self.projects.find_one( - { - "project_id": project_id, - "tasks.task_id": task_id - }, - # Projection, only get the specific task - { - "tasks.$": 1, - } - ) + { + "project_id": project_id, + "tasks.task_id": task_id + }, + # Projection, only get the specific task + { + "tasks.$": 1, + } + ) if not res: raise TaskNotFoundError(project_id=project_id, task_id=task_id) return _task_schema.load(res["tasks"][0]) def update_task_values(self, project_id, task_id, key_value_dict, operation="set"): self._update_task_fields(project_id=project_id, - task_id=task_id, - key_value_dict=key_value_dict, - operation=operation) + task_id=task_id, + key_value_dict=key_value_dict, + operation=operation) - def delete_task(self, project_id, task_id): - # delete_res = self.projects. + def delete_task(self, project_id, task_id): + # delete_res = self.projects. self._delete_task(project_id=project_id, - task_id=task_id) + task_id=task_id) - def get_job_by_id(self, project_id, task_id, job_id): - # We cannot do double nested updates with mongoDB, so we get the task + def get_job_by_id(self, project_id, task_id, job_id): + # We cannot do double nested updates with mongoDB, so we get the task # and search for the task in the result set try: task = self.get_task_by_id(project_id=project_id, - task_id=task_id) + task_id=task_id) except TaskNotFoundError: raise JobNotFoundError(project_id=project_id, - task_id=task_id, - job_id=job_id) - # Search for the job + task_id=task_id, + job_id=job_id) + # Search for the job # noinspection PyUnresolvedReferences job = [job for job in task.jobs if job.job_id == job_id] if not job: raise JobNotFoundError(project_id=project_id, - task_id=task_id, - job_id=job_id) - return job[0] + task_id=task_id, + job_id=job_id) + return job[0]