207 lines
8.0 KiB
Python
Executable File
207 lines
8.0 KiB
Python
Executable File
"""
Training Queue Manager

Manages a queue of training jobs and tracks their progress.
"""
|
|
import threading
|
|
import queue
|
|
import subprocess
|
|
import re
|
|
import os
|
|
from services.settings_service import get_setting
|
|
from models.training import Training
|
|
|
|
class TrainingQueueManager:
    """Singleton that runs queued training commands one at a time.

    Jobs are plain dicts placed on a ``queue.Queue`` and consumed by a single
    daemon worker thread. While a job runs, its combined stdout/stderr is
    scanned for ``epoch: X/Y`` and ``iter: X/Y`` markers; the parsed values
    are stored on the job dict so that :meth:`get_status` can report them.
    """

    # Fallback epoch count used when the training record does not provide one.
    DEFAULT_MAX_EPOCH = 300

    # Progress markers, e.g. "epoch: 3/300, iter: 90/101". Compiled once
    # because they are applied to every line of training output.
    _EPOCH_RE = re.compile(r'epoch:\s*(\d+)/(\d+)', re.IGNORECASE)
    _ITER_RE = re.compile(r'iter:\s*(\d+)/(\d+)', re.IGNORECASE)

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up queue state and start the worker (first construction only)."""
        # __init__ runs on every TrainingQueueManager() call; skip repeats so
        # the existing queue and worker are not replaced.
        if self._initialized:
            return

        self.queue = queue.Queue()
        self.current_training = None   # job dict currently being executed
        self.current_process = None    # subprocess.Popen of the running job
        self.worker_thread = None
        self.running = False
        self._initialized = True

        # Start the worker thread
        self.start_worker()

    def start_worker(self):
        """Start the background worker thread if it is not already alive."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.running = True
            self.worker_thread = threading.Thread(
                target=self._process_queue, daemon=True
            )
            self.worker_thread.start()

    def add_to_queue(self, training_id, command, cwd):
        """Add a training job to the queue.

        Args:
            training_id: Identifier used to look up the ``Training`` record.
            command: Command line handed to the shell when the job runs.
            cwd: Working directory for the command.
        """
        job = {
            'training_id': training_id,
            'command': command,
            'cwd': cwd,
            'iteration': 0,
            'max_epoch': self.DEFAULT_MAX_EPOCH,  # Updated from the record below
        }

        # Best effort: enrich the job from the training record. A failed
        # lookup must not prevent the job from being queued, so the error is
        # reported and the defaults above are kept.
        try:
            training = Training.query.get(training_id)
            if training:
                job['max_epoch'] = training.max_epoch or self.DEFAULT_MAX_EPOCH
                job['name'] = training.exp_name or f'Training {training_id}'
        except Exception as e:
            print(f'Could not load training record {training_id}, using defaults: {e}')

        self.queue.put(job)
        print(f'Added training {training_id} to queue. Queue size: {self.queue.qsize()}')

    def _process_queue(self):
        """Worker loop: take jobs off the queue and run them sequentially."""
        while self.running:
            try:
                # Block with a timeout so the loop re-checks self.running.
                job = self.queue.get(timeout=1)
            except queue.Empty:
                continue

            print(f'Starting training {job["training_id"]} from queue')
            self.current_training = job
            try:
                self._run_training(job)
            except Exception as e:
                print(f'Error processing training job: {e}')
            finally:
                # Always reset state and balance the get() above, even when
                # the job failed, so Queue.join() cannot hang.
                self.current_training = None
                self.current_process = None
                self.queue.task_done()

    def _run_training(self, job):
        """Run a job's command and record progress parsed from its output.

        Args:
            job: Job dict with at least ``command``, ``cwd`` and
                ``training_id``; progress fields are written back onto it.
        """
        try:
            # NOTE(review): the command is executed through the shell
            # (shell=True). That is only safe if callers never build it from
            # untrusted input -- confirm at the call sites.
            # errors='replace' keeps undecodable bytes in the training output
            # from aborting the monitoring loop.
            with subprocess.Popen(
                job['command'],
                shell=True,
                cwd=job['cwd'],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                errors='replace',
                bufsize=1,
            ) as process:
                self.current_process = process

                # readline() returns '' only at EOF, i.e. once the process
                # has closed its output.
                for line in iter(process.stdout.readline, ''):
                    print(line.strip())
                    self._update_progress(job, line)

                process.wait()
                print(f'Training {job["training_id"]} completed with exit code {process.returncode}')
            # Leaving the ``with`` block closes the pipe and waits for the
            # process, on the error path as well.
        except Exception as e:
            print(f'Error running training: {e}')

    @classmethod
    def _update_progress(cls, job, line):
        """Parse one output line and store epoch/iteration progress on ``job``.

        Example of a recognised line: ``epoch: 3/300, iter: 90/101``.

        Args:
            job: Job dict to update in place.
            line: A single line of training output.
        """
        epoch_match = cls._EPOCH_RE.search(line)
        if epoch_match:
            current_epoch = int(epoch_match.group(1))
            total_epochs = int(epoch_match.group(2))
            job['current_epoch'] = current_epoch
            job['max_epoch'] = total_epochs
            # Debug log
            print(f'[PROGRESS] Parsed epoch: {current_epoch}/{total_epochs}')

        iter_match = cls._ITER_RE.search(line)
        if not iter_match:
            return

        current_iter = int(iter_match.group(1))
        total_iters = int(iter_match.group(2))
        job['current_iter'] = current_iter
        job['total_iters'] = total_iters

        # The percentage needs a known epoch and non-zero totals; a marker
        # such as "iter: 0/0" must not raise ZeroDivisionError.
        max_epoch = job.get('max_epoch')
        if 'current_epoch' not in job or not max_epoch or not total_iters:
            return

        # Completed epochs plus the fraction of the current epoch.
        epoch_progress = (job['current_epoch'] - 1) / max_epoch
        iter_progress = current_iter / total_iters / max_epoch
        total_progress = (epoch_progress + iter_progress) * 100
        # Clamp so odd markers (e.g. epoch 0) cannot yield values outside 0-100.
        job['progress'] = round(min(max(total_progress, 0.0), 100.0), 2)
        # Debug log
        print(f'[PROGRESS] Epoch {job["current_epoch"]}/{job["max_epoch"]}, Iter {current_iter}/{total_iters}, Progress: {job["progress"]}%')

    def get_status(self):
        """Get current status of the training queue.

        Returns:
            dict with ``current`` (progress of the running job, or ``None``)
            and ``queue`` (list of waiting jobs, in queue order).
        """
        # Snapshot the pending jobs under the queue's own mutex instead of
        # draining and refilling it: removing items would race with the
        # worker thread and each re-put would inflate unfinished_tasks.
        with self.queue.mutex:
            pending = list(self.queue.queue)

        queue_items = [
            {
                'training_id': item['training_id'],
                'name': item.get('name', f'Training {item["training_id"]}'),
                'max_epoch': item.get('max_epoch', self.DEFAULT_MAX_EPOCH),
            }
            for item in pending
        ]

        result = {
            'current': None,
            'queue': queue_items,
        }

        # Read the attribute once: the worker thread may reset it to None at
        # any moment.
        current = self.current_training
        if current:
            current_epoch = current.get('current_epoch', 0)
            max_epoch = current.get('max_epoch', self.DEFAULT_MAX_EPOCH)
            result['current'] = {
                'training_id': current['training_id'],
                'name': current.get('name', f'Training {current["training_id"]}'),
                'epoch': current_epoch,  # For backward compatibility
                'current_epoch': current_epoch,
                'max_epoch': max_epoch,
                'current_iter': current.get('current_iter', 0),
                'total_iters': current.get('total_iters', 0),
                'progress': current.get('progress', 0.0),
                'iteration': current_epoch,  # For backward compatibility
            }

        return result

    def stop(self):
        """Stop the worker loop and terminate the running training process."""
        self.running = False
        # Read the attribute once: the worker thread may reset it to None.
        process = self.current_process
        if process:
            try:
                process.terminate()
            except OSError as e:
                print(f'Could not terminate training process: {e}')
|
|
|
|
# Global instance
|
|
# Shared module-level manager. Constructing it here means the background
# worker thread is started when this module is imported.
training_queue = TrainingQueueManager()
|