add multi trainings +fix path for coco retraining
This commit is contained in:
@@ -1,496 +0,0 @@
|
||||
const express = require('express');
|
||||
const multer = require('multer');
|
||||
const upload = multer();
|
||||
const TrainingProject = require('../models/TrainingProject.js');
|
||||
const LabelStudioProject = require('../models/LabelStudioProject.js')
|
||||
const { seedLabelStudio, updateStatus } = require('../services/seed-label-studio.js');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const {generateTrainingJson} = require('../services/generate-json-yolox.js')
|
||||
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
// Ensure JSON bodies are parsed for all routes
|
||||
router.use(express.json());
|
||||
|
||||
router.get('/seed', async (req, res) => {
|
||||
const result = await seedLabelStudio();
|
||||
res.json(result);
|
||||
});
|
||||
|
||||
|
||||
|
||||
// Trigger generate-json-yolox.js
|
||||
|
||||
router.post('/generate-yolox-json', async (req, res) => {
|
||||
const { project_id } = req.body;
|
||||
if (!project_id) {
|
||||
return res.status(400).json({ message: 'Missing project_id in request body' });
|
||||
}
|
||||
try {
|
||||
// Generate COCO JSONs
|
||||
// Find all TrainingProjectDetails for this project
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const detailsRows = await TrainingProjectDetails.findAll({ where: { project_id } });
|
||||
if (!detailsRows || detailsRows.length === 0) {
|
||||
return res.status(404).json({ message: 'No TrainingProjectDetails found for project ' + project_id });
|
||||
}
|
||||
// For each details row, generate coco.jsons and exp.py in projectfolder/project_details_id
|
||||
const Training = require('../models/training.js');
|
||||
const { saveYoloxExp } = require('../services/generate-yolox-exp.js');
|
||||
const TrainingProject = require('../models/TrainingProject.js');
|
||||
const trainingProject = await TrainingProject.findByPk(project_id);
|
||||
const projectName = trainingProject.name ? trainingProject.name.replace(/\s+/g, '_') : `project_${project_id}`;
|
||||
for (const details of detailsRows) {
|
||||
const detailsId = details.id;
|
||||
await generateTrainingJson(detailsId);
|
||||
const trainings = await Training.findAll({ where: { project_details_id: detailsId } });
|
||||
if (trainings.length === 0) continue;
|
||||
// For each training, save exp.py in projectfolder/project_details_id
|
||||
const outDir = path.join(__dirname, '..', projectName, String(detailsId));
|
||||
if (!fs.existsSync(outDir)) fs.mkdirSync(outDir, { recursive: true });
|
||||
for (const training of trainings) {
|
||||
const expFilePath = path.join(outDir, 'exp.py');
|
||||
await saveYoloxExp(training.id, expFilePath);
|
||||
}
|
||||
}
|
||||
|
||||
// Find all trainings for this project
|
||||
// ...existing code...
|
||||
res.json({ message: 'YOLOX JSON and exp.py generated for project ' + project_id });
|
||||
} catch (err) {
|
||||
console.error('Error generating YOLOX JSON:', err);
|
||||
res.status(500).json({ message: 'Failed to generate YOLOX JSON', error: err.message });
|
||||
}
|
||||
});
|
||||
|
||||
// Start YOLOX training
|
||||
const { spawn } = require('child_process');
|
||||
router.post('/start-yolox-training', async (req, res) => {
|
||||
try {
|
||||
const { project_id, training_id } = req.body;
|
||||
// Get project name
|
||||
const trainingProject = await TrainingProject.findByPk(project_id);
|
||||
const projectName = trainingProject.name ? trainingProject.name.replace(/\s+/g, '_') : `project_${project_id}`;
|
||||
// Look up training row by id or project_details_id
|
||||
const Training = require('../models/training.js');
|
||||
let trainingRow = await Training.findByPk(training_id);
|
||||
if (!trainingRow) {
|
||||
trainingRow = await Training.findOne({ where: { project_details_id: training_id } });
|
||||
}
|
||||
if (!trainingRow) {
|
||||
return res.status(404).json({ error: `Training row not found for id or project_details_id ${training_id}` });
|
||||
}
|
||||
const project_details_id = trainingRow.project_details_id;
|
||||
// Use the generated exp.py from the correct project folder
|
||||
const outDir = path.join(__dirname, '..', projectName, String(project_details_id));
|
||||
const yoloxMainDir = '/home/kitraining/Yolox/YOLOX-main';
|
||||
const expSrc = path.join(outDir, 'exp.py');
|
||||
if (!fs.existsSync(expSrc)) {
|
||||
return res.status(500).json({ error: `exp.py not found at ${expSrc}` });
|
||||
}
|
||||
// Source venv and run YOLOX training in YOLOX-main folder
|
||||
const yoloxVenv = '/home/kitraining/Yolox/yolox_venv/bin/activate';
|
||||
// Determine model argument based on selected_model and transfer_learning
|
||||
let modelArg = '';
|
||||
let cmd = '';
|
||||
if (
|
||||
trainingRow.transfer_learning &&
|
||||
typeof trainingRow.transfer_learning === 'string' &&
|
||||
trainingRow.transfer_learning.toLowerCase() === 'coco'
|
||||
) {
|
||||
// If transfer_learning is 'coco', add -o and modelArg
|
||||
modelArg = ` -c /home/kitraining/Yolox/YOLOX-main/pretrained/${trainingRow.selected_model}`;
|
||||
cmd = `bash -c 'source ${yoloxVenv} && python tools/train.py -f ${expSrc} -d 1 -b 8 --fp16 -o ${modelArg}.pth --cache'`;
|
||||
} else if (
|
||||
trainingRow.selected_model &&
|
||||
trainingRow.selected_model.toLowerCase() === 'coco' &&
|
||||
(!trainingRow.transfer_learning || trainingRow.transfer_learning === false)
|
||||
) {
|
||||
// If selected_model is 'coco' and not transfer_learning, add modelArg only
|
||||
modelArg = ` -c /pretrained/${trainingRow.selected_model}`;
|
||||
cmd = `bash -c 'source ${yoloxVenv} && python tools/train.py -f ${expSrc} -d 1 -b 8 --fp16 -o ${modelArg}.pth --cache'`;
|
||||
} else {
|
||||
// Default: no modelArg
|
||||
cmd = `bash -c 'source ${yoloxVenv} && python tools/train.py -f ${expSrc} -d 1 -b 8 --fp16' --cache`;
|
||||
}
|
||||
console.log(cmd)
|
||||
const child = spawn(cmd, { shell: true, cwd: yoloxMainDir });
|
||||
child.stdout.pipe(process.stdout);
|
||||
child.stderr.pipe(process.stderr);
|
||||
|
||||
res.json({ message: 'Training started' });
|
||||
} catch (err) {
|
||||
res.status(500).json({ error: 'Failed to start training', details: err.message });
|
||||
}
|
||||
});
|
||||
|
||||
// Get YOLOX training log
|
||||
router.get('/training-log', async (req, res) => {
|
||||
try {
|
||||
const { project_id, training_id } = req.query;
|
||||
const trainingProject = await TrainingProject.findByPk(project_id);
|
||||
const projectName = trainingProject.name ? trainingProject.name.replace(/\s+/g, '_') : `project_${project_id}`;
|
||||
const outDir = path.join(__dirname, '..', projectName, String(training_id));
|
||||
const logPath = path.join(outDir, 'training.log');
|
||||
if (!fs.existsSync(logPath)) {
|
||||
return res.status(404).json({ error: 'Log not found' });
|
||||
}
|
||||
const logData = fs.readFileSync(logPath, 'utf8');
|
||||
res.json({ log: logData });
|
||||
} catch (err) {
|
||||
res.status(500).json({ error: 'Failed to fetch log', details: err.message });
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/training-projects', upload.single('project_image'), async (req, res) => {
|
||||
try {
|
||||
const { title, description } = req.body;
|
||||
const classes = JSON.parse(req.body.classes);
|
||||
const project_image = req.file ? req.file.buffer : null;
|
||||
const project_image_type = req.file ? req.file.mimetype : null;
|
||||
await TrainingProject.create({
|
||||
title,
|
||||
description,
|
||||
classes,
|
||||
project_image,
|
||||
project_image_type
|
||||
});
|
||||
res.json({ message: 'Project created!' });
|
||||
} catch (error) {
|
||||
console.error('Error creating project:', error);
|
||||
res.status(500).json({ message: 'Failed to create project', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/training-projects', async (req, res) => {
|
||||
try {
|
||||
const projects = await TrainingProject.findAll();
|
||||
// Convert BLOB to base64 data URL for each project
|
||||
const serialized = projects.map(project => {
|
||||
const plain = project.get({ plain: true });
|
||||
if (plain.project_image) {
|
||||
const base64 = Buffer.from(plain.project_image).toString('base64');
|
||||
const mimeType = plain.project_image_type || 'image/png';
|
||||
plain.project_image = `data:${mimeType};base64,${base64}`;
|
||||
}
|
||||
return plain;
|
||||
});
|
||||
res.json(serialized);
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to fetch projects', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/update-status', async (req, res) => {
|
||||
res.json(updateStatus)
|
||||
})
|
||||
|
||||
router.get('/label-studio-projects', async (req, res) => {
|
||||
try {
|
||||
const LabelStudioProject = require('../models/LabelStudioProject.js');
|
||||
const Image = require('../models/Images.js');
|
||||
const Annotation = require('../models/Annotation.js');
|
||||
const labelStudioProjects = await LabelStudioProject.findAll();
|
||||
const projectsWithCounts = await Promise.all(labelStudioProjects.map(async project => {
|
||||
const plain = project.get({ plain: true });
|
||||
// Get all images for this project
|
||||
const images = await Image.findAll({ where: { project_id: plain.project_id } });
|
||||
let annotationCounts = {};
|
||||
if (images.length > 0) {
|
||||
const imageIds = images.map(img => img.image_id);
|
||||
// Get all annotations for these images
|
||||
const annotations = await Annotation.findAll({ where: { image_id: imageIds } });
|
||||
// Count by label
|
||||
for (const ann of annotations) {
|
||||
const label = ann.Label;
|
||||
annotationCounts[label] = (annotationCounts[label] || 0) + 1;
|
||||
}
|
||||
}
|
||||
plain.annotationCounts = annotationCounts;
|
||||
return plain;
|
||||
}));
|
||||
res.json(projectsWithCounts);
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to fetch projects', error: error.message });
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
// POST endpoint to create TrainingProjectDetails with all fields
|
||||
router.post('/training-project-details', async (req, res) => {
|
||||
try {
|
||||
const {
|
||||
project_id,
|
||||
annotation_projects,
|
||||
class_map,
|
||||
description
|
||||
} = req.body;
|
||||
if (!project_id || !annotation_projects) {
|
||||
return res.status(400).json({ message: 'Missing required fields' });
|
||||
}
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const created = await TrainingProjectDetails.create({
|
||||
project_id,
|
||||
annotation_projects,
|
||||
class_map: class_map || null,
|
||||
description: description || null
|
||||
});
|
||||
res.json({ message: 'TrainingProjectDetails created', details: created });
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to create TrainingProjectDetails', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// GET endpoint to fetch all TrainingProjectDetails
|
||||
router.get('/training-project-details', async (req, res) => {
|
||||
try {
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const details = await TrainingProjectDetails.findAll();
|
||||
res.json(details);
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to fetch TrainingProjectDetails', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// PUT endpoint to update class_map and description in TrainingProjectDetails
|
||||
router.put('/training-project-details', async (req, res) => {
|
||||
try {
|
||||
const { project_id, class_map, description } = req.body;
|
||||
if (!project_id || !class_map || !description) {
|
||||
return res.status(400).json({ message: 'Missing required fields' });
|
||||
}
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const details = await TrainingProjectDetails.findOne({ where: { project_id } });
|
||||
if (!details) {
|
||||
return res.status(404).json({ message: 'TrainingProjectDetails not found' });
|
||||
}
|
||||
details.class_map = class_map;
|
||||
details.description = description;
|
||||
await details.save();
|
||||
res.json({ message: 'Class map and description updated', details });
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to update class map or description', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// POST endpoint to receive YOLOX settings and save to DB (handles multipart/form-data)
|
||||
router.post('/yolox-settings', upload.any(), async (req, res) => {
|
||||
try {
|
||||
const settings = req.body;
|
||||
// Debug: Log all received fields and types
|
||||
console.log('--- YOLOX settings received ---');
|
||||
console.log('settings:', settings);
|
||||
if (req.files && req.files.length > 0) {
|
||||
console.log('Files received:', req.files.map(f => ({ fieldname: f.fieldname, originalname: f.originalname, size: f.size })));
|
||||
}
|
||||
// Declare requiredFields once
|
||||
const requiredFields = ['project_details_id', 'exp_name', 'max_epoch', 'depth', 'width', 'activation', 'train', 'valid', 'test', 'selected_model', 'transfer_learning'];
|
||||
// Log types of required fields
|
||||
requiredFields.forEach(field => {
|
||||
console.log(`Field '${field}': value='${settings[field]}', type='${typeof settings[field]}'`);
|
||||
});
|
||||
// Map select_model to selected_model if present
|
||||
if (settings && settings.select_model && !settings.selected_model) {
|
||||
settings.selected_model = settings.select_model;
|
||||
delete settings.select_model;
|
||||
}
|
||||
// Lookup project_details_id from project_id
|
||||
if (!settings.project_id || isNaN(Number(settings.project_id))) {
|
||||
throw new Error('Missing or invalid project_id in request. Cannot assign training to a project.');
|
||||
}
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
let details = await TrainingProjectDetails.findOne({ where: { project_id: settings.project_id } });
|
||||
if (!details) {
|
||||
details = await TrainingProjectDetails.create({
|
||||
project_id: settings.project_id,
|
||||
annotation_projects: [],
|
||||
class_map: null,
|
||||
description: null
|
||||
});
|
||||
}
|
||||
settings.project_details_id = details.id;
|
||||
// Map 'act' from frontend to 'activation' for DB
|
||||
if (settings.act !== undefined) {
|
||||
settings.activation = settings.act;
|
||||
delete settings.act;
|
||||
}
|
||||
// Type conversion for DB compatibility
|
||||
[
|
||||
'max_epoch', 'depth', 'width', 'warmup_epochs', 'warmup_lr', 'no_aug_epochs', 'min_lr_ratio', 'weight_decay', 'momentum', 'print_interval', 'eval_interval', 'test_conf', 'nmsthre', 'multiscale_range', 'degrees', 'translate', 'shear', 'train', 'valid', 'test'
|
||||
].forEach(f => {
|
||||
if (settings[f] !== undefined) settings[f] = Number(settings[f]);
|
||||
});
|
||||
// Improved boolean conversion
|
||||
['ema', 'enable_mixup', 'save_history_ckpt'].forEach(f => {
|
||||
if (settings[f] !== undefined) {
|
||||
if (typeof settings[f] === 'string') {
|
||||
settings[f] = settings[f].toLowerCase() === 'true';
|
||||
} else {
|
||||
settings[f] = Boolean(settings[f]);
|
||||
}
|
||||
}
|
||||
});
|
||||
// Improved array conversion
|
||||
['mosaic_scale', 'mixup_scale', 'scale'].forEach(f => {
|
||||
if (settings[f] && typeof settings[f] === 'string') {
|
||||
settings[f] = settings[f]
|
||||
.split(',')
|
||||
.map(s => Number(s.trim()))
|
||||
.filter(n => !isNaN(n));
|
||||
}
|
||||
});
|
||||
// Trim all string fields
|
||||
Object.keys(settings).forEach(f => {
|
||||
if (typeof settings[f] === 'string') settings[f] = settings[f].trim();
|
||||
});
|
||||
// Set default for transfer_learning if missing
|
||||
if (settings.transfer_learning === undefined) settings.transfer_learning = false;
|
||||
// Convert empty string seed to null
|
||||
if ('seed' in settings && (settings.seed === '' || settings.seed === undefined)) {
|
||||
settings.seed = null;
|
||||
}
|
||||
// Validate required fields for training table
|
||||
for (const field of requiredFields) {
|
||||
if (settings[field] === undefined || settings[field] === null || settings[field] === '') {
|
||||
console.error('Missing required field:', field, 'Value:', settings[field]);
|
||||
throw new Error('Missing required field: ' + field);
|
||||
}
|
||||
}
|
||||
console.log('Received YOLOX settings:', settings);
|
||||
// Handle uploaded model file (ckpt_upload)
|
||||
if (req.files && req.files.length > 0) {
|
||||
const ckptFile = req.files.find(f => f.fieldname === 'ckpt_upload');
|
||||
if (ckptFile) {
|
||||
const uploadDir = path.join(__dirname, '..', 'uploads');
|
||||
if (!fs.existsSync(uploadDir)) fs.mkdirSync(uploadDir);
|
||||
const filename = ckptFile.originalname || `uploaded_model_${settings.project_id}.pth`;
|
||||
const filePath = path.join(uploadDir, filename);
|
||||
fs.writeFileSync(filePath, ckptFile.buffer);
|
||||
settings.model_upload = filePath;
|
||||
}
|
||||
}
|
||||
// Save settings to DB only (no file)
|
||||
const { pushYoloxExpToDb } = require('../services/push-yolox-exp.js');
|
||||
const training = await pushYoloxExpToDb(settings);
|
||||
res.json({ message: 'YOLOX settings saved to DB', training });
|
||||
} catch (error) {
|
||||
console.error('Error in /api/yolox-settings:', error.stack || error);
|
||||
res.status(500).json({ message: 'Failed to save YOLOX settings', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// POST endpoint to receive binary model file and save to disk (not DB)
|
||||
router.post('/yolox-settings/upload', async (req, res) => {
|
||||
try {
|
||||
const projectId = req.query.project_id;
|
||||
if (!projectId) return res.status(400).json({ message: 'Missing project_id in query' });
|
||||
// Save file to disk
|
||||
const uploadDir = path.join(__dirname, '..', 'uploads');
|
||||
if (!fs.existsSync(uploadDir)) fs.mkdirSync(uploadDir);
|
||||
const filename = req.headers['x-upload-filename'] || `uploaded_model_${projectId}.pth`;
|
||||
const filePath = path.join(uploadDir, filename);
|
||||
const chunks = [];
|
||||
req.on('data', chunk => chunks.push(chunk));
|
||||
req.on('end', async () => {
|
||||
const buffer = Buffer.concat(chunks);
|
||||
fs.writeFile(filePath, buffer, async err => {
|
||||
if (err) {
|
||||
console.error('Error saving file:', err);
|
||||
return res.status(500).json({ message: 'Failed to save model file', error: err.message });
|
||||
}
|
||||
// Update latest training row for this project with file path
|
||||
try {
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const Training = require('../models/training.js');
|
||||
// Find details row for this project
|
||||
const details = await TrainingProjectDetails.findOne({ where: { project_id: projectId } });
|
||||
if (!details) return res.status(404).json({ message: 'No TrainingProjectDetails found for project_id' });
|
||||
// Find latest training for this details row
|
||||
const training = await Training.findOne({ where: { project_details_id: details.id }, order: [['createdAt', 'DESC']] });
|
||||
if (!training) return res.status(404).json({ message: 'No training found for project_id' });
|
||||
// Save file path to model_upload field
|
||||
training.model_upload = filePath;
|
||||
await training.save();
|
||||
res.json({ message: 'Model file uploaded and saved to disk', filename, trainingId: training.id });
|
||||
} catch (dbErr) {
|
||||
console.error('Error updating training with file path:', dbErr);
|
||||
res.status(500).json({ message: 'File saved but failed to update training row', error: dbErr.message });
|
||||
}
|
||||
});
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error in /api/yolox-settings/upload:', error.stack || error);
|
||||
res.status(500).json({ message: 'Failed to upload model file', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// GET endpoint to fetch all trainings (optionally filtered by project_id)
|
||||
router.get('/trainings', async (req, res) => {
|
||||
try {
|
||||
const project_id = req.query.project_id;
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const Training = require('../models/training.js');
|
||||
if (project_id) {
|
||||
// Find all details rows for this project
|
||||
const detailsRows = await TrainingProjectDetails.findAll({ where: { project_id } });
|
||||
if (!detailsRows || detailsRows.length === 0) return res.json([]);
|
||||
// Get all trainings linked to any details row for this project
|
||||
const detailsIds = detailsRows.map(d => d.id);
|
||||
const trainings = await Training.findAll({ where: { project_details_id: detailsIds } });
|
||||
return res.json(trainings);
|
||||
} else {
|
||||
// Return all trainings if no project_id is specified
|
||||
const trainings = await Training.findAll();
|
||||
return res.json(trainings);
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to fetch trainings', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// DELETE endpoint to remove a training by id
|
||||
router.delete('/trainings/:id', async (req, res) => {
|
||||
try {
|
||||
const Training = require('../models/training.js');
|
||||
const id = req.params.id;
|
||||
const deleted = await Training.destroy({ where: { id } });
|
||||
if (deleted) {
|
||||
res.json({ message: 'Training deleted' });
|
||||
} else {
|
||||
res.status(404).json({ message: 'Training not found' });
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to delete training', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
// DELETE endpoint to remove a training project and all related entries
|
||||
router.delete('/training-projects/:id', async (req, res) => {
|
||||
try {
|
||||
const projectId = req.params.id;
|
||||
const TrainingProject = require('../models/TrainingProject.js');
|
||||
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
||||
const Training = require('../models/training.js');
|
||||
// Find details row(s) for this project
|
||||
const detailsRows = await TrainingProjectDetails.findAll({ where: { project_id: projectId } });
|
||||
const detailsIds = detailsRows.map(d => d.id);
|
||||
// Delete all trainings linked to these details
|
||||
if (detailsIds.length > 0) {
|
||||
await Training.destroy({ where: { project_details_id: detailsIds } });
|
||||
await TrainingProjectDetails.destroy({ where: { project_id: projectId } });
|
||||
}
|
||||
// Delete the project itself
|
||||
const deleted = await TrainingProject.destroy({ where: { project_id: projectId } });
|
||||
if (deleted) {
|
||||
res.json({ message: 'Training project and all related entries deleted' });
|
||||
} else {
|
||||
res.status(404).json({ message: 'Training project not found' });
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Failed to delete training project', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
@@ -130,11 +130,17 @@ def start_yolox_training():
|
||||
if (training.transfer_learning and
|
||||
isinstance(training.transfer_learning, str) and
|
||||
training.transfer_learning.lower() == 'coco'):
|
||||
model_arg = f'-c {yolox_main_dir}/pretrained/{training.selected_model}.pth'
|
||||
# Use yolox_path setting to construct pretrained model path
|
||||
model_path = os.path.join(yolox_main_dir, 'pretrained', f'{training.selected_model}.pth')
|
||||
model_path = model_path.replace('\\', '/') # Use forward slashes for command line
|
||||
model_arg = f'-c {model_path}'
|
||||
elif (training.selected_model and
|
||||
training.selected_model.lower() == 'coco' and
|
||||
(not training.transfer_learning or training.transfer_learning == False)):
|
||||
model_arg = f'-c {yolox_main_dir}/pretrained/{training.selected_model}.pth'
|
||||
# Use yolox_path setting to construct pretrained model path
|
||||
model_path = os.path.join(yolox_main_dir, 'pretrained', f'{training.selected_model}.pth')
|
||||
model_path = model_path.replace('\\', '/') # Use forward slashes for command line
|
||||
model_arg = f'-c {model_path}'
|
||||
|
||||
# Build base training arguments
|
||||
train_args = f'-f {exp_file_path} -d 1 -b 8 --fp16 --cache'
|
||||
|
||||
Reference in New Issue
Block a user