150 lines
7.1 KiB
Python
150 lines
7.1 KiB
Python
from database.database import db
|
|
from models.LabelStudioProject import LabelStudioProject
|
|
from models.Images import Image
|
|
from models.Annotation import Annotation
|
|
from services.fetch_labelstudio import fetch_label_studio_project, fetch_project_ids_and_titles
|
|
|
|
# Module-level status flag: seed_label_studio() sets "running" to True when it
# starts and resets it to False in its finally block, whether seeding
# succeeded or failed.
update_status = {"running": False}
|
|
|
|
def _region_label(region):
    """Return the label of one region dict, or ``'unknown'`` if it has none.

    ``rectanglelabels`` may be a list (its first entry is used) or a bare
    string (used as-is).
    """
    labels = region.get('rectanglelabels', [])
    if isinstance(labels, list) and labels:
        return labels[0]
    if isinstance(labels, str):
        return labels
    return 'unknown'


def _select_regions(task):
    """Return the list of regions that describes *task*.

    ``label_rectangles`` is preferred and ``label`` is the fallback. A value
    that is empty or not a list counts as absent, so the list that supplies
    the image dimensions is always the same list the boxes are read from.
    """
    for key in ('label_rectangles', 'label'):
        regions = task.get(key)
        if isinstance(regions, list) and regions:
            return regions
    return []


def _build_project_rows(project_id, tasks):
    """Turn exported tasks into ``(image_kwargs, [annotation_kwargs])`` pairs.

    Tasks without regions, or whose first region has no usable
    ``original_width``/``original_height``, are skipped. Region ``x``, ``y``,
    ``width`` and ``height`` are treated as percentages of the image size and
    converted with ``value * dimension / 100``.

    Raises:
        KeyError: if a region lacks ``x``, ``y``, ``width`` or ``height``.
    """
    rows = []
    for task in tasks:
        regions = _select_regions(task)
        if not regions:
            continue

        width = regions[0].get('original_width')
        height = regions[0].get('original_height')
        # Without both dimensions the percentages cannot be scaled.
        if not (width and height):
            continue

        image_kwargs = {
            'project_id': project_id,
            'image_path': task.get('image'),
            'width': int(width),  # Ensure integer
            'height': int(height),  # Ensure integer
        }
        boxes = [
            {
                'x': (region['x'] * width) / 100,
                'y': (region['y'] * height) / 100,
                'width': (region['width'] * width) / 100,
                'height': (region['height'] * height) / 100,
                'Label': _region_label(region),
            }
            for region in regions
        ]
        rows.append((image_kwargs, boxes))
    return rows


def seed_label_studio():
    """Seed database with Label Studio project data.

    For every project returned by ``fetch_project_ids_and_titles()`` the
    project row is upserted, then the project's images and annotations are
    replaced by the data from ``fetch_label_studio_project()``. Projects whose
    fetch result is not a non-empty list keep their existing rows.

    The replacement for one project (delete old rows + insert new rows) is
    committed as a single transaction, so a failure while inserting does not
    leave the project emptied.

    ``update_status["running"]`` is True for the duration of the call.

    Returns:
        dict: ``{'success': True, 'message': ...}`` on success, or
        ``{'success': False, 'message': <error text>}`` if any exception was
        raised (the session is rolled back in that case).
    """
    update_status["running"] = True
    print('Seeding started')

    try:
        projects = fetch_project_ids_and_titles()

        for project in projects:
            project_id = project['id']
            print(f"Processing project {project_id} ({project['title']})")

            # Upsert project in DB
            existing_project = LabelStudioProject.query.filter_by(project_id=project_id).first()
            if existing_project:
                existing_project.title = project['title']
            else:
                db.session.add(LabelStudioProject(project_id=project_id, title=project['title']))
            db.session.commit()

            # Fetch project data (annotations array)
            data = fetch_label_studio_project(project_id)
            if not isinstance(data, list) or not data:
                print(f"No annotation data for project {project_id}")
                continue

            # Build every row before touching the DB so that a malformed
            # export fails here, while the old data is still intact.
            rows = _build_project_rows(project_id, data)

            # Remove old images and annotations for this project (not yet
            # committed: it shares a transaction with the inserts below).
            old_images = Image.query.filter_by(project_id=project_id).all()
            old_image_ids = [img.image_id for img in old_images]
            if old_image_ids:
                Annotation.query.filter(
                    Annotation.image_id.in_(old_image_ids)
                ).delete(synchronize_session=False)
                Image.query.filter_by(project_id=project_id).delete()

            # Insert images; one flush assigns all generated image_ids.
            new_images = [Image(**image_kwargs) for image_kwargs, _ in rows]
            db.session.add_all(new_images)
            db.session.flush()

            # Link each annotation to the image built from the same task
            # rather than looking it up by image_path, which is ambiguous
            # when several tasks share a path.
            annotation_count = 0
            for image, (_, boxes) in zip(new_images, rows):
                db.session.add_all(
                    [Annotation(image_id=image.image_id, **box) for box in boxes]
                )
                annotation_count += len(boxes)

            db.session.commit()

            if old_image_ids:
                print(f"Deleted {len(old_image_ids)} old images and their annotations for project {project_id}")
            print(f"Inserted {len(rows)} images and {annotation_count} annotations for project {project_id}")

        print('Seeding done')
        return {'success': True, 'message': 'Data inserted successfully!'}

    except Exception as error:
        # Top-level boundary: report the failure to the caller instead of raising.
        print(f'Error inserting data: {error}')
        db.session.rollback()
        return {'success': False, 'message': str(error)}

    finally:
        update_status["running"] = False
        print('updateStatus.running set to false')
|