feat: 系统优化和功能完善

主要更新:
- 调整并发配置为50人(数据库连接池30,Tomcat线程150,异步线程池5/20)
- 实现无界阻塞队列(LinkedBlockingQueue)任务处理
- 实现分镜视频保存功能(保存到uploads目录)
- 统一管理页面导航栏和右上角样式
- 添加日活用户统计功能
- 优化视频拼接和保存逻辑
- 添加部署文档和快速部署指南
- 更新.gitignore排除敏感配置文件
This commit is contained in:
AIGC Developer
2025-11-07 19:09:50 +08:00
parent b5820d9be2
commit 1e71ae6a26
146 changed files with 10720 additions and 3032 deletions

View File

@@ -2,11 +2,9 @@ package com.example.demo;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication
@EnableAsync
@EnableScheduling
public class DemoApplication {
@@ -17,6 +15,13 @@ public class DemoApplication {
System.setProperty("sun.net.client.defaultConnectTimeout", "30000");
System.setProperty("sun.net.client.defaultReadTimeout", "120000");
// 增加HTTP缓冲区大小以支持大请求体Base64编码的图片可能很大
// 设置Socket缓冲区大小为10MB
System.setProperty("java.net.preferIPv4Stack", "true");
// Apache HttpClient 使用系统属性
System.setProperty("org.apache.http.client.connection.timeout", "30000");
System.setProperty("org.apache.http.socket.timeout", "300000");
SpringApplication.run(DemoApplication.class, args);
}

View File

@@ -0,0 +1,46 @@
package com.example.demo.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
/**
* 异步执行器配置
* 支持50人并发处理异步任务如视频生成、图片处理等
*/
@Configuration
@EnableAsync
public class AsyncConfig {
/**
* 配置异步任务执行器
* 核心线程数5最大线程数20队列容量50
* 可支持50人并发每个用户最多3个任务共150个任务
* 大部分任务在队列中等待,实际并发执行的任务数量受线程池限制
*/
@Bean(name = "taskExecutor")
public Executor taskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数:保持活跃的最小线程数
executor.setCorePoolSize(5);
// 最大线程数:最大并发执行的任务数
executor.setMaxPoolSize(20);
// 队列容量:等待执行的任务数
executor.setQueueCapacity(50);
// 线程名前缀
executor.setThreadNamePrefix("async-task-");
// 拒绝策略:当线程池和队列都满时,使用调用者线程执行(保证任务不丢失)
executor.setRejectedExecutionHandler(new java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成后再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间(秒)
executor.setAwaitTerminationSeconds(60);
executor.initialize();
return executor;
}
}

View File

@@ -19,8 +19,8 @@ public class PollingConfig implements SchedulingConfigurer {
@Override
public void configureTasks(@NonNull ScheduledTaskRegistrar taskRegistrar) {
// 使用自定义线程池执行定时任务
ScheduledExecutorService executor = Executors.newScheduledThreadPool(2);
// 使用自定义线程池执行定时任务支持50人并发
ScheduledExecutorService executor = Executors.newScheduledThreadPool(5);
taskRegistrar.setScheduler(executor);
}
}

View File

@@ -50,6 +50,7 @@ public class SecurityConfig {
.requestMatchers("/api/image-to-video/**").authenticated() // 图生视频接口需要认证
.requestMatchers("/api/text-to-video/**").authenticated() // 文生视频接口需要认证
.requestMatchers("/api/dashboard/**").hasRole("ADMIN") // 仪表盘API需要管理员权限
.requestMatchers("/api/admin/**").hasRole("ADMIN") // 管理员API需要管理员权限
.requestMatchers("/settings", "/settings/**").hasRole("ADMIN")
.requestMatchers("/users/**").hasRole("ADMIN")
.anyRequest().authenticated()

View File

@@ -0,0 +1,92 @@
package com.example.demo.config;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.orm.jpa.JpaTransactionManager;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.annotation.EnableTransactionManagement;
import org.springframework.transaction.support.TransactionTemplate;
import jakarta.persistence.EntityManagerFactory;
/**
* 事务管理器配置
* 确保事务不会长时间占用数据库连接
*/
@Configuration
@EnableTransactionManagement
public class TransactionManagerConfig {
@Autowired
private EntityManagerFactory entityManagerFactory;
/**
* 配置事务管理器
* 使用 JpaTransactionManager 以支持 JPA 操作(包括悲观锁)
* 注意:超时时间在 TransactionTemplate 中设置,而不是在 TransactionManager 中
* 这样可以更精确地控制不同场景下的超时时间
*/
@Bean
public PlatformTransactionManager transactionManager() {
JpaTransactionManager transactionManager = new JpaTransactionManager();
transactionManager.setEntityManagerFactory(entityManagerFactory);
// 设置是否允许嵌套事务
transactionManager.setNestedTransactionAllowed(true);
// 设置是否在回滚时验证事务状态
transactionManager.setValidateExistingTransaction(true);
// 注意:超时时间在 TransactionTemplate 中设置,而不是在 TransactionManager 中
// 这样可以更精确地控制不同场景下的超时时间
return transactionManager;
}
/**
* 配置用于异步方法的事务模板
* 使用更短的超时时间3秒确保异步线程中的事务快速完成
*/
@Bean(name = "asyncTransactionTemplate")
public TransactionTemplate asyncTransactionTemplate(PlatformTransactionManager transactionManager) {
TransactionTemplate template = new TransactionTemplate(transactionManager);
// 异步方法中的事务超时时间设置为3秒
// 确保异步线程中的数据库操作快速完成,不会长时间占用连接
template.setTimeout(3);
// 设置传播行为为 REQUIRES_NEW确保每个操作都是独立事务
template.setPropagationBehavior(org.springframework.transaction.TransactionDefinition.PROPAGATION_REQUIRES_NEW);
// 设置隔离级别为 READ_COMMITTED默认
template.setIsolationLevel(org.springframework.transaction.TransactionDefinition.ISOLATION_READ_COMMITTED);
// 设置只读标志默认false允许写操作
template.setReadOnly(false);
return template;
}
/**
* 配置用于只读操作的事务模板
* 使用更短的超时时间2秒确保只读操作快速完成
*/
@Bean(name = "readOnlyTransactionTemplate")
public TransactionTemplate readOnlyTransactionTemplate(PlatformTransactionManager transactionManager) {
TransactionTemplate template = new TransactionTemplate(transactionManager);
// 只读操作超时时间设置为2秒
template.setTimeout(2);
// 设置传播行为为 REQUIRES_NEW
template.setPropagationBehavior(org.springframework.transaction.TransactionDefinition.PROPAGATION_REQUIRES_NEW);
// 设置只读标志为 true
template.setReadOnly(true);
return template;
}
}

View File

@@ -2,10 +2,12 @@ package com.example.demo.config;
import java.util.Locale;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.LocaleResolver;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.servlet.i18n.LocaleChangeInterceptor;
import org.springframework.web.servlet.i18n.SessionLocaleResolver;
@@ -13,6 +15,9 @@ import org.springframework.web.servlet.i18n.SessionLocaleResolver;
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
@Value("${app.upload.path:uploads}")
private String uploadPath;
@Bean
public LocaleResolver localeResolver() {
SessionLocaleResolver slr = new SessionLocaleResolver();
@@ -32,6 +37,28 @@ public class WebMvcConfig implements WebMvcConfigurer {
registry.addInterceptor(localeChangeInterceptor());
}
/**
* 配置静态资源服务使上传的文件可以通过URL访问
* 访问路径:/uploads/** -> 映射到 uploads/ 目录
*/
@Override
public void addResourceHandlers(ResourceHandlerRegistry registry) {
// 将 /uploads/** 映射到 uploads/ 目录
// 处理路径:如果是相对路径,转换为绝对路径;如果是绝对路径,直接使用
java.nio.file.Path uploadDirPath = java.nio.file.Paths.get(uploadPath);
if (!uploadDirPath.isAbsolute()) {
// 相对路径:基于应用运行目录
uploadDirPath = java.nio.file.Paths.get(System.getProperty("user.dir"), uploadPath);
}
// 确保路径使用正斜杠URL格式
String resourceLocation = "file:" + uploadDirPath.toAbsolutePath().toString().replace("\\", "/") + "/";
registry.addResourceHandler("/uploads/**")
.addResourceLocations(resourceLocation)
.setCachePeriod(3600); // 缓存1小时
}
// CORS配置已移至SecurityConfig避免冲突
}

View File

@@ -1,15 +1,28 @@
package com.example.demo.controller;
import com.example.demo.service.UserService;
import com.example.demo.util.JwtUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.Map;
import com.example.demo.model.User;
import com.example.demo.service.UserService;
import com.example.demo.util.JwtUtils;
/**
* 管理员控制器
@@ -129,5 +142,241 @@ public class AdminController {
return null;
}
}
/**
* 获取所有用户列表
*/
@GetMapping("/users")
public ResponseEntity<Map<String, Object>> getAllUsers(
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
// 验证管理员权限
String adminUsername = extractUsernameFromToken(token);
if (adminUsername == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
// 获取所有用户
List<User> users = userService.findAll();
// 转换为DTO格式
List<Map<String, Object>> userList = users.stream().map(user -> {
Map<String, Object> userMap = new HashMap<>();
userMap.put("id", user.getId());
userMap.put("username", user.getUsername());
userMap.put("email", user.getEmail());
userMap.put("role", user.getRole());
userMap.put("points", user.getPoints());
userMap.put("frozenPoints", user.getFrozenPoints());
userMap.put("createdAt", user.getCreatedAt());
userMap.put("lastLoginAt", user.getLastLoginAt());
userMap.put("isActive", user.getIsActive());
return userMap;
}).collect(Collectors.toList());
response.put("success", true);
response.put("data", userList);
logger.info("管理员 {} 获取用户列表,共 {} 个用户", adminUsername, users.size());
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("获取用户列表失败", e);
response.put("success", false);
response.put("message", "获取用户列表失败: " + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 创建新用户
*/
@PostMapping("/users")
public ResponseEntity<Map<String, Object>> createUser(
@RequestBody Map<String, Object> userData,
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
// 验证管理员权限
String adminUsername = extractUsernameFromToken(token);
if (adminUsername == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
// 提取用户数据
String username = (String) userData.get("username");
String email = (String) userData.get("email");
String password = (String) userData.get("password");
String role = (String) userData.getOrDefault("role", "ROLE_USER");
// 验证必填字段
if (username == null || username.isBlank()) {
response.put("success", false);
response.put("message", "用户名不能为空");
return ResponseEntity.badRequest().body(response);
}
if (email == null || email.isBlank()) {
response.put("success", false);
response.put("message", "邮箱不能为空");
return ResponseEntity.badRequest().body(response);
}
if (password == null || password.isBlank()) {
response.put("success", false);
response.put("message", "密码不能为空");
return ResponseEntity.badRequest().body(response);
}
// 创建用户
User user = userService.create(username, email, password);
// 如果指定了角色,更新角色
if (!role.equals("ROLE_USER")) {
userService.update(user.getId(), username, email, null, role);
user = userService.findById(user.getId());
}
// 构建响应
Map<String, Object> userMap = new HashMap<>();
userMap.put("id", user.getId());
userMap.put("username", user.getUsername());
userMap.put("email", user.getEmail());
userMap.put("role", user.getRole());
userMap.put("points", user.getPoints());
userMap.put("createdAt", user.getCreatedAt());
response.put("success", true);
response.put("message", "用户创建成功");
response.put("data", userMap);
logger.info("管理员 {} 创建用户: {}", adminUsername, username);
return ResponseEntity.ok(response);
} catch (IllegalArgumentException e) {
logger.error("创建用户失败: {}", e.getMessage());
response.put("success", false);
response.put("message", e.getMessage());
return ResponseEntity.badRequest().body(response);
} catch (Exception e) {
logger.error("创建用户失败", e);
response.put("success", false);
response.put("message", "创建用户失败: " + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 更新用户信息
*/
@PutMapping("/users/{id}")
public ResponseEntity<Map<String, Object>> updateUser(
@PathVariable Long id,
@RequestBody Map<String, Object> userData,
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
// 验证管理员权限
String adminUsername = extractUsernameFromToken(token);
if (adminUsername == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
// 提取用户数据
String username = (String) userData.get("username");
String email = (String) userData.get("email");
String password = (String) userData.get("password");
String role = (String) userData.get("role");
// 验证必填字段
if (username == null || username.isBlank()) {
response.put("success", false);
response.put("message", "用户名不能为空");
return ResponseEntity.badRequest().body(response);
}
if (email == null || email.isBlank()) {
response.put("success", false);
response.put("message", "邮箱不能为空");
return ResponseEntity.badRequest().body(response);
}
// 更新用户(密码可选)
User user = userService.update(id, username, email, password, role);
// 构建响应
Map<String, Object> userMap = new HashMap<>();
userMap.put("id", user.getId());
userMap.put("username", user.getUsername());
userMap.put("email", user.getEmail());
userMap.put("role", user.getRole());
userMap.put("points", user.getPoints());
userMap.put("updatedAt", user.getUpdatedAt());
response.put("success", true);
response.put("message", "用户更新成功");
response.put("data", userMap);
logger.info("管理员 {} 更新用户: {}", adminUsername, username);
return ResponseEntity.ok(response);
} catch (IllegalArgumentException e) {
logger.error("更新用户失败: {}", e.getMessage());
response.put("success", false);
response.put("message", e.getMessage());
return ResponseEntity.badRequest().body(response);
} catch (Exception e) {
logger.error("更新用户失败", e);
response.put("success", false);
response.put("message", "更新用户失败: " + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 删除用户
*/
@DeleteMapping("/users/{id}")
public ResponseEntity<Map<String, Object>> deleteUser(
@PathVariable Long id,
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
// 验证管理员权限
String adminUsername = extractUsernameFromToken(token);
if (adminUsername == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
// 删除用户
userService.delete(id);
response.put("success", true);
response.put("message", "用户删除成功");
logger.info("管理员 {} 删除用户ID: {}", adminUsername, id);
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("删除用户失败", e);
response.put("success", false);
response.put("message", "删除用户失败: " + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
}

View File

@@ -0,0 +1,213 @@
package com.example.demo.controller;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/api-key")
@CrossOrigin(origins = "*")
public class ApiKeyController {
private static final Logger logger = LoggerFactory.getLogger(ApiKeyController.class);
@Value("${spring.profiles.active:dev}")
private String activeProfile;
@Value("${ai.api.key:}")
private String currentApiKey;
@Value("${jwt.expiration:86400000}")
private Long currentJwtExpiration;
/**
* 获取当前API密钥和JWT配置仅显示部分用于验证
*/
@GetMapping
public ResponseEntity<Map<String, Object>> getApiKey() {
try {
Map<String, Object> response = new HashMap<>();
// 只返回密钥的前4位和后4位中间用*代替
if (currentApiKey != null && currentApiKey.length() > 8) {
String masked = currentApiKey.substring(0, 4) + "****" + currentApiKey.substring(currentApiKey.length() - 4);
response.put("maskedKey", masked);
} else {
response.put("maskedKey", "****");
}
// 返回JWT过期时间毫秒
response.put("jwtExpiration", currentJwtExpiration);
// 转换为小时显示
response.put("jwtExpirationHours", currentJwtExpiration / 3600000.0);
response.put("success", true);
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("获取API密钥失败", e);
Map<String, Object> error = new HashMap<>();
error.put("error", "获取API密钥失败");
error.put("message", e.getMessage());
return ResponseEntity.status(500).body(error);
}
}
/**
* 更新API密钥和JWT配置到配置文件
*/
@PutMapping
public ResponseEntity<Map<String, Object>> updateApiKey(@RequestBody Map<String, Object> request) {
try {
String newApiKey = (String) request.get("apiKey");
Object jwtExpirationObj = request.get("jwtExpiration");
// 验证API密钥
if (newApiKey != null && newApiKey.trim().isEmpty()) {
newApiKey = null; // 如果为空字符串,则不更新
}
// 验证JWT过期时间
Long newJwtExpiration = null;
if (jwtExpirationObj != null) {
if (jwtExpirationObj instanceof Number) {
newJwtExpiration = ((Number) jwtExpirationObj).longValue();
} else if (jwtExpirationObj instanceof String) {
try {
newJwtExpiration = Long.parseLong((String) jwtExpirationObj);
} catch (NumberFormatException e) {
Map<String, Object> error = new HashMap<>();
error.put("error", "JWT过期时间格式错误");
error.put("message", "JWT过期时间必须是数字毫秒");
return ResponseEntity.badRequest().body(error);
}
}
// 验证过期时间范围至少1小时最多30天
if (newJwtExpiration != null && (newJwtExpiration < 3600000 || newJwtExpiration > 2592000000L)) {
Map<String, Object> error = new HashMap<>();
error.put("error", "JWT过期时间超出范围");
error.put("message", "JWT过期时间必须在1小时3600000毫秒到30天2592000000毫秒之间");
return ResponseEntity.badRequest().body(error);
}
}
// 如果都没有提供,返回错误
if (newApiKey == null && newJwtExpiration == null) {
Map<String, Object> error = new HashMap<>();
error.put("error", "至少需要提供一个配置项");
error.put("message", "请提供API密钥或JWT过期时间");
return ResponseEntity.badRequest().body(error);
}
// 确定配置文件路径
String configFileName = "application-" + activeProfile + ".properties";
Path configPath = getConfigFilePath(configFileName);
// 读取现有配置
Properties props = new Properties();
if (Files.exists(configPath)) {
try (FileInputStream fis = new FileInputStream(configPath.toFile())) {
props.load(fis);
}
}
// 更新API密钥
if (newApiKey != null) {
props.setProperty("ai.api.key", newApiKey);
props.setProperty("ai.image.api.key", newApiKey); // 同时更新图片API密钥
logger.info("API密钥已更新");
}
// 更新JWT过期时间
if (newJwtExpiration != null) {
props.setProperty("jwt.expiration", String.valueOf(newJwtExpiration));
logger.info("JWT过期时间已更新: {} 毫秒 ({} 小时)", newJwtExpiration, newJwtExpiration / 3600000.0);
}
// 保存配置文件
try (FileOutputStream fos = new FileOutputStream(configPath.toFile())) {
props.store(fos, "Updated by API Key Management");
}
logger.info("配置已更新到配置文件: {}", configPath);
Map<String, Object> response = new HashMap<>();
response.put("success", true);
StringBuilder message = new StringBuilder();
if (newApiKey != null) {
message.append("API密钥已更新。");
}
if (newJwtExpiration != null) {
message.append("JWT过期时间已更新。");
}
message.append("请重启应用以使配置生效。");
response.put("message", message.toString());
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("更新配置失败", e);
Map<String, Object> error = new HashMap<>();
error.put("error", "更新配置失败");
error.put("message", e.getMessage());
return ResponseEntity.status(500).body(error);
}
}
/**
* 获取配置文件路径
* 优先使用外部配置文件如果不存在则使用classpath中的配置文件
*/
private Path getConfigFilePath(String fileName) throws IOException {
// 尝试从外部配置目录查找
String externalConfigDir = System.getProperty("user.dir");
Path externalPath = Paths.get(externalConfigDir, "config", fileName);
if (Files.exists(externalPath)) {
return externalPath;
}
// 尝试从项目根目录查找
Path rootPath = Paths.get(externalConfigDir, "src", "main", "resources", fileName);
if (Files.exists(rootPath)) {
return rootPath;
}
// 尝试从classpath复制到外部目录
ClassPathResource resource = new ClassPathResource(fileName);
if (resource.exists()) {
// 创建config目录
Path configDir = Paths.get(externalConfigDir, "config");
Files.createDirectories(configDir);
// 复制文件到外部目录
Path targetPath = configDir.resolve(fileName);
try (InputStream is = resource.getInputStream();
FileOutputStream fos = new FileOutputStream(targetPath.toFile())) {
is.transferTo(fos);
}
return targetPath;
}
// 如果都不存在,创建新的配置文件
Path configDir = Paths.get(externalConfigDir, "config");
Files.createDirectories(configDir);
return configDir.resolve(fileName);
}
}

View File

@@ -1,19 +1,28 @@
package com.example.demo.controller;
import com.example.demo.repository.UserRepository;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import com.example.demo.model.Order;
import com.example.demo.repository.MembershipLevelRepository;
import com.example.demo.repository.OrderRepository;
import com.example.demo.repository.PaymentRepository;
import com.example.demo.repository.UserMembershipRepository;
import com.example.demo.repository.MembershipLevelRepository;
import com.example.demo.model.Order;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.time.LocalDateTime;
import java.util.*;
import java.util.stream.Collectors;
import org.springframework.data.domain.PageRequest;
import com.example.demo.repository.UserRepository;
@RestController
@RequestMapping("/api/dashboard")
@@ -130,7 +139,7 @@ public class DashboardApiController {
// 获取用户转化率数据
@GetMapping("/conversion-rate")
public ResponseEntity<Map<String, Object>> getConversionRate() {
public ResponseEntity<Map<String, Object>> getConversionRate(@RequestParam(required = false) String year) {
try {
Map<String, Object> response = new HashMap<>();
@@ -147,6 +156,12 @@ public class DashboardApiController {
response.put("paidUsers", paidUsers);
response.put("conversionRate", Math.round(conversionRate * 100.0) / 100.0);
// 如果指定了年份,返回按月转化率数据
if (year != null && !year.isEmpty()) {
List<Map<String, Object>> monthlyConversion = getMonthlyConversionRate(Integer.parseInt(year));
response.put("monthlyData", monthlyConversion);
}
// 按会员等级统计
List<Map<String, Object>> membershipStats = membershipLevelRepository.findMembershipStats();
response.put("membershipStats", membershipStats);
@@ -160,6 +175,36 @@ public class DashboardApiController {
return ResponseEntity.status(500).body(error);
}
}
// 获取按月转化率数据
private List<Map<String, Object>> getMonthlyConversionRate(int year) {
List<Map<String, Object>> monthlyData = new ArrayList<>();
for (int month = 1; month <= 12; month++) {
Map<String, Object> monthData = new HashMap<>();
monthData.put("month", month);
// 计算该月的总用户数(注册时间在该月)
LocalDateTime monthStart = LocalDateTime.of(year, month, 1, 0, 0, 0);
LocalDateTime monthEnd = monthStart.plusMonths(1).minusSeconds(1);
long monthTotalUsers = userRepository.countByCreatedAtBetween(monthStart, monthEnd);
// 计算该月新增的付费用户数(会员开始时间在该月)
long monthPaidUsers = userMembershipRepository.countByStartDateBetween(monthStart, monthEnd);
// 计算该月转化率
double monthConversionRate = monthTotalUsers > 0 ? (double) monthPaidUsers / monthTotalUsers * 100 : 0.0;
monthData.put("totalUsers", monthTotalUsers);
monthData.put("paidUsers", monthPaidUsers);
monthData.put("conversionRate", Math.round(monthConversionRate * 100.0) / 100.0);
monthlyData.add(monthData);
}
return monthlyData;
}
// 获取最近订单数据
@GetMapping("/recent-orders")

View File

@@ -70,16 +70,17 @@ public class ImageToVideoApiController {
return ResponseEntity.badRequest().body(response);
}
// 验证文件大小最大10MB
if (firstFrame.getSize() > 10 * 1024 * 1024) {
// 验证文件大小最大100MB,与文件上传配置保持一致
long maxFileSize = 100 * 1024 * 1024; // 100MB
if (firstFrame.getSize() > maxFileSize) {
response.put("success", false);
response.put("message", "首帧图片大小不能超过10MB");
response.put("message", "首帧图片大小不能超过100MB");
return ResponseEntity.badRequest().body(response);
}
if (lastFrame != null && !lastFrame.isEmpty() && lastFrame.getSize() > 10 * 1024 * 1024) {
if (lastFrame != null && !lastFrame.isEmpty() && lastFrame.getSize() > maxFileSize) {
response.put("success", false);
response.put("message", "尾帧图片大小不能超过10MB");
response.put("message", "尾帧图片大小不能超过100MB");
return ResponseEntity.badRequest().body(response);
}
@@ -209,42 +210,6 @@ public class ImageToVideoApiController {
}
}
/**
* 取消任务
*/
@PostMapping("/tasks/{taskId}/cancel")
public ResponseEntity<Map<String, Object>> cancelTask(
@PathVariable String taskId,
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
String username = extractUsernameFromToken(token);
if (username == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
boolean success = imageToVideoService.cancelTask(taskId, username);
if (success) {
response.put("success", true);
response.put("message", "任务已取消");
} else {
response.put("success", false);
response.put("message", "任务取消失败或任务不存在");
}
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("取消任务失败", e);
response.put("success", false);
response.put("message", "取消任务失败:" + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 获取任务状态

View File

@@ -1,23 +1,32 @@
package com.example.demo.controller;
import com.example.demo.model.User;
import com.example.demo.model.UserMembership;
import com.example.demo.model.MembershipLevel;
import com.example.demo.repository.UserRepository;
import com.example.demo.repository.UserMembershipRepository;
import com.example.demo.repository.MembershipLevelRepository;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import com.example.demo.model.MembershipLevel;
import com.example.demo.model.User;
import com.example.demo.model.UserMembership;
import com.example.demo.repository.MembershipLevelRepository;
import com.example.demo.repository.UserMembershipRepository;
import com.example.demo.repository.UserRepository;
@RestController
@RequestMapping("/api/members")
@@ -260,4 +269,97 @@ public class MemberApiController {
return ResponseEntity.status(500).body(error);
}
}
// 获取所有会员等级配置(用于系统设置和订阅页面)
@GetMapping("/levels")
public ResponseEntity<Map<String, Object>> getMembershipLevels() {
try {
List<MembershipLevel> levels = membershipLevelRepository.findAll();
List<Map<String, Object>> levelList = levels.stream()
.map(level -> {
Map<String, Object> levelMap = new HashMap<>();
levelMap.put("id", level.getId());
levelMap.put("name", level.getName());
levelMap.put("displayName", level.getDisplayName());
levelMap.put("description", level.getDescription());
levelMap.put("price", level.getPrice());
levelMap.put("durationDays", level.getDurationDays());
levelMap.put("pointsBonus", level.getPointsBonus());
levelMap.put("features", level.getFeatures());
levelMap.put("isActive", level.getIsActive());
return levelMap;
})
.toList();
Map<String, Object> response = new HashMap<>();
response.put("success", true);
response.put("data", levelList);
return ResponseEntity.ok(response);
} catch (Exception e) {
Map<String, Object> error = new HashMap<>();
error.put("error", "获取会员等级配置失败");
error.put("message", e.getMessage());
return ResponseEntity.status(500).body(error);
}
}
// 更新会员等级价格和配置
@PutMapping("/levels/{id}")
public ResponseEntity<Map<String, Object>> updateMembershipLevel(
@PathVariable Long id,
@RequestBody Map<String, Object> updateData) {
try {
Optional<MembershipLevel> levelOpt = membershipLevelRepository.findById(id);
if (levelOpt.isEmpty()) {
return ResponseEntity.notFound().build();
}
MembershipLevel level = levelOpt.get();
// 更新价格
if (updateData.containsKey("price")) {
Object priceObj = updateData.get("price");
if (priceObj instanceof Number) {
level.setPrice(((Number) priceObj).doubleValue());
} else if (priceObj instanceof String) {
level.setPrice(Double.parseDouble((String) priceObj));
}
}
// 更新资源点数量
if (updateData.containsKey("pointsBonus") || updateData.containsKey("resourcePoints")) {
Object pointsObj = updateData.get("pointsBonus") != null
? updateData.get("pointsBonus")
: updateData.get("resourcePoints");
if (pointsObj instanceof Number) {
level.setPointsBonus(((Number) pointsObj).intValue());
} else if (pointsObj instanceof String) {
level.setPointsBonus(Integer.parseInt((String) pointsObj));
}
}
// 更新描述
if (updateData.containsKey("description")) {
level.setDescription((String) updateData.get("description"));
}
level.setUpdatedAt(java.time.LocalDateTime.now());
membershipLevelRepository.save(level);
Map<String, Object> response = new HashMap<>();
response.put("success", true);
response.put("message", "会员等级配置更新成功");
return ResponseEntity.ok(response);
} catch (Exception e) {
Map<String, Object> error = new HashMap<>();
error.put("error", "更新会员等级配置失败");
error.put("message", e.getMessage());
return ResponseEntity.status(500).body(error);
}
}
}

View File

@@ -1,8 +1,9 @@
package com.example.demo.controller;
import com.example.demo.model.*;
import com.example.demo.service.OrderService;
import jakarta.validation.Valid;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -12,11 +13,22 @@ import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.example.demo.model.Order;
import com.example.demo.model.OrderStatus;
import com.example.demo.model.PaymentMethod;
import com.example.demo.model.User;
import com.example.demo.service.OrderService;
import jakarta.validation.Valid;
@RestController
@RequestMapping("/api/orders")
@@ -90,14 +102,38 @@ public class OrderApiController {
public ResponseEntity<Map<String, Object>> getOrderById(@PathVariable Long id,
Authentication authentication) {
try {
// 检查认证信息
if (authentication == null || !authentication.isAuthenticated()) {
Map<String, Object> response = new HashMap<>();
response.put("success", false);
response.put("message", "用户未认证,请重新登录");
return ResponseEntity.status(401).body(response);
}
User user = (User) authentication.getPrincipal();
if (user == null) {
Map<String, Object> response = new HashMap<>();
response.put("success", false);
response.put("message", "用户信息获取失败,请重新登录");
return ResponseEntity.status(401).body(response);
}
Order order = orderService.findById(id)
.orElseThrow(() -> new RuntimeException("订单不存在"));
.orElse(null);
if (order == null) {
Map<String, Object> response = new HashMap<>();
response.put("success", false);
response.put("message", "订单不存在");
return ResponseEntity.status(404).body(response);
}
// 检查权限
if (!user.getRole().equals("ROLE_ADMIN") && !order.getUser().getId().equals(user.getId())) {
return ResponseEntity.badRequest()
.body(createErrorResponse("无权限访问此订单"));
Map<String, Object> response = new HashMap<>();
response.put("success", false);
response.put("message", "无权限访问此订单");
return ResponseEntity.status(403).body(response);
}
Map<String, Object> response = new HashMap<>();
@@ -108,8 +144,10 @@ public class OrderApiController {
} catch (Exception e) {
logger.error("获取订单详情失败:", e);
return ResponseEntity.badRequest()
.body(createErrorResponse("获取订单详情失败:" + e.getMessage()));
Map<String, Object> response = new HashMap<>();
response.put("success", false);
response.put("message", "获取订单详情失败:" + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}

View File

@@ -399,16 +399,25 @@ public class PaymentApiController {
.body(createErrorResponse("用户不存在"));
}
// 获取用户最近一次成功的订阅支付
logger.info("开始查询用户订阅记录用户ID: {}", user.getId());
List<Payment> subscriptions;
// 获取用户最近一次成功的支付记录(包括所有充值记录,不仅仅是订阅)
logger.info("开始查询用户支付记录用户ID: {}", user.getId());
List<Payment> allPayments;
try {
subscriptions = paymentRepository.findLatestSuccessfulSubscriptionByUserId(user.getId(), PaymentStatus.SUCCESS);
logger.info("用户 {} (ID: {}) 的订阅记录数量: {}", username, user.getId(), subscriptions.size());
// 获取所有成功支付的记录,按支付时间倒序
allPayments = paymentRepository.findByUserIdOrderByCreatedAtDesc(user.getId())
.stream()
.filter(p -> p.getStatus() == PaymentStatus.SUCCESS)
.sorted((p1, p2) -> {
LocalDateTime time1 = p1.getPaidAt() != null ? p1.getPaidAt() : p1.getCreatedAt();
LocalDateTime time2 = p2.getPaidAt() != null ? p2.getPaidAt() : p2.getCreatedAt();
return time2.compareTo(time1); // 倒序
})
.collect(java.util.stream.Collectors.toList());
logger.info("用户 {} (ID: {}) 的成功支付记录数量: {}", username, user.getId(), allPayments.size());
} catch (Exception e) {
logger.error("查询订阅记录失败用户ID: {}", user.getId(), e);
logger.error("查询支付记录失败用户ID: {}", user.getId(), e);
// 如果查询失败,使用空列表
subscriptions = new ArrayList<>();
allPayments = new ArrayList<>();
}
Map<String, Object> subscriptionInfo = new HashMap<>();
@@ -418,26 +427,58 @@ public class PaymentApiController {
String expiryTime = "永久";
LocalDateTime paidAt = null;
if (!subscriptions.isEmpty()) {
logger.info("找到订阅记录,第一条描述: {}", subscriptions.get(0).getDescription());
Payment latestSubscription = subscriptions.get(0);
String description = latestSubscription.getDescription();
paidAt = latestSubscription.getPaidAt() != null ?
latestSubscription.getPaidAt() : latestSubscription.getCreatedAt();
// 使用最近的充值记录来确定会员权益
if (!allPayments.isEmpty()) {
Payment latestPayment = allPayments.get(0);
String description = latestPayment.getDescription();
paidAt = latestPayment.getPaidAt() != null ?
latestPayment.getPaidAt() : latestPayment.getCreatedAt();
// 从描述中识别套餐类型
logger.info("使用最近的支付记录ID: {}, 描述: {}, 金额: {}, 支付时间: {}",
latestPayment.getId(), description, latestPayment.getAmount(), paidAt);
// 从描述或金额中识别套餐类型
if (description != null) {
if (description.contains("标准版")) {
if (description.contains("标准版") || description.contains("standard") ||
description.contains("Standard") || description.contains("STANDARD")) {
currentPlan = "标准版会员";
} else if (description.contains("专业版")) {
} else if (description.contains("专业版") || description.contains("premium") ||
description.contains("Premium") || description.contains("PREMIUM")) {
currentPlan = "专业版会员";
} else if (description.contains("会员")) {
currentPlan = "会员";
} else {
// 如果描述中没有套餐信息,根据金额判断
java.math.BigDecimal amount = latestPayment.getAmount();
if (amount != null) {
// 标准版订阅 (59-258元) - 200积分
if (amount.compareTo(new java.math.BigDecimal("59.00")) >= 0 &&
amount.compareTo(new java.math.BigDecimal("259.00")) < 0) {
currentPlan = "标准版会员";
logger.info("根据金额 {} 判断为标准版会员", amount);
}
// 专业版订阅 (259元以上) - 1000积分
else if (amount.compareTo(new java.math.BigDecimal("259.00")) >= 0) {
currentPlan = "专业版会员";
logger.info("根据金额 {} 判断为专业版会员", amount);
}
}
}
} else {
// 如果描述为空,根据金额判断
java.math.BigDecimal amount = latestPayment.getAmount();
if (amount != null) {
if (amount.compareTo(new java.math.BigDecimal("59.00")) >= 0 &&
amount.compareTo(new java.math.BigDecimal("259.00")) < 0) {
currentPlan = "标准版会员";
} else if (amount.compareTo(new java.math.BigDecimal("259.00")) >= 0) {
currentPlan = "专业版会员";
}
}
}
// 计算到期时间假设订阅有效期为30天
if (paidAt != null) {
if (paidAt != null && !currentPlan.equals("免费版")) {
LocalDateTime expiryDateTime = paidAt.plusDays(30);
LocalDateTime now = LocalDateTime.now();
@@ -447,6 +488,8 @@ public class PaymentApiController {
} else {
// 已过期,显示已过期
expiryTime = "已过期";
// 如果已过期,恢复为免费版
currentPlan = "免费版";
}
}
}
@@ -460,11 +503,25 @@ public class PaymentApiController {
subscriptionInfo.put("email", user.getEmail());
subscriptionInfo.put("nickname", user.getNickname());
logger.info("=== 用户订阅信息 ===");
logger.info("当前套餐: {}", currentPlan);
logger.info("到期时间: {}", expiryTime);
logger.info("支付时间: {}", paidAt);
logger.info("积分: {}", user.getPoints());
logger.info("成功支付记录数: {}", allPayments.size());
if (!allPayments.isEmpty()) {
logger.info("最近支付记录: ID={}, 描述={}, 金额={}, 时间={}",
allPayments.get(0).getId(),
allPayments.get(0).getDescription(),
allPayments.get(0).getAmount(),
allPayments.get(0).getPaidAt() != null ? allPayments.get(0).getPaidAt() : allPayments.get(0).getCreatedAt());
}
Map<String, Object> response = new HashMap<>();
response.put("success", true);
response.put("data", subscriptionInfo);
logger.info("=== 用户订阅信息获取成功 ===");
logger.info("=== 用户订阅信息获取成功,返回数据: {} ===", subscriptionInfo);
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("获取用户订阅信息失败", e);

View File

@@ -1,18 +1,24 @@
package com.example.demo.controller;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import com.example.demo.model.PointsFreezeRecord;
import com.example.demo.model.User;
import com.example.demo.service.UserService;
import com.example.demo.util.JwtUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 积分冻结API控制器
@@ -101,6 +107,43 @@ public class PointsApiController {
}
}
/**
* 获取积分使用历史(充值和使用记录)
*/
@GetMapping("/history")
public ResponseEntity<Map<String, Object>> getPointsHistory(
@RequestHeader("Authorization") String token,
@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "50") int size) {
Map<String, Object> response = new HashMap<>();
try {
String username = extractUsernameFromToken(token);
if (username == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
// 获取积分使用历史
List<Map<String, Object>> history = userService.getPointsHistory(username, page, size);
response.put("success", true);
response.put("data", history);
response.put("page", page);
response.put("size", size);
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("获取积分使用历史失败", e);
response.put("success", false);
response.put("message", "获取积分使用历史失败:" + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 手动处理过期冻结记录(管理员功能)
*/

View File

@@ -148,4 +148,43 @@ public class StoryboardVideoApiController {
.body(Map.of("success", false, "message", "查询失败"));
}
}
/**
* 开始生成视频(从分镜图生成视频)
* 用户点击"开始生成"按钮后调用
*/
@PostMapping("/task/{taskId}/start-video")
public ResponseEntity<?> startVideoGeneration(
@PathVariable String taskId,
Authentication authentication) {
try {
String username = authentication.getName();
logger.info("收到开始生成视频请求任务ID: {}, 用户: {}", taskId, username);
// 验证任务是否存在且属于该用户
StoryboardVideoTask task = storyboardVideoService.getTask(taskId);
if (!task.getUsername().equals(username)) {
logger.warn("用户 {} 尝试访问任务 {},但任务属于用户 {}", username, taskId, task.getUsername());
return ResponseEntity.status(403)
.body(Map.of("success", false, "message", "无权访问此任务"));
}
// 开始生成视频
storyboardVideoService.startVideoGeneration(taskId);
return ResponseEntity.ok(Map.of(
"success", true,
"message", "视频生成任务已启动"
));
} catch (RuntimeException e) {
logger.error("开始生成视频失败: {}", e.getMessage());
return ResponseEntity.badRequest()
.body(Map.of("success", false, "message", e.getMessage()));
} catch (Exception e) {
logger.error("开始生成视频异常", e);
return ResponseEntity.internalServerError()
.body(Map.of("success", false, "message", "启动视频生成失败"));
}
}
}

View File

@@ -239,39 +239,6 @@ public class TextToVideoApiController {
}
}
/**
* 取消文生视频任务
*/
@PostMapping("/tasks/{taskId}/cancel")
public ResponseEntity<Map<String, Object>> cancelTask(
@PathVariable String taskId,
@RequestHeader("Authorization") String token) {
Map<String, Object> response = new HashMap<>();
try {
String username = extractUsernameFromToken(token);
if (username == null) {
response.put("success", false);
response.put("message", "用户未登录");
return ResponseEntity.status(401).body(response);
}
boolean cancelled = textToVideoService.cancelTask(taskId, username);
if (cancelled) {
response.put("success", true);
response.put("message", "任务已取消");
} else {
response.put("success", false);
response.put("message", "任务取消失败或任务不存在/无权限");
}
return ResponseEntity.ok(response);
} catch (Exception e) {
logger.error("取消任务失败", e);
response.put("success", false);
response.put("message", "取消任务失败:" + e.getMessage());
return ResponseEntity.status(500).body(response);
}
}
/**
* 从Token中提取用户名

View File

@@ -72,3 +72,6 @@ public class MailMessage {

View File

@@ -1,8 +1,16 @@
package com.example.demo.model;
import jakarta.persistence.*;
import java.time.LocalDateTime;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
/**
* 积分冻结记录实体
* 记录每次积分冻结的详细信息
@@ -49,7 +57,8 @@ public class PointsFreezeRecord {
*/
public enum TaskType {
TEXT_TO_VIDEO("文生视频"),
IMAGE_TO_VIDEO("图生视频");
IMAGE_TO_VIDEO("图生视频"),
STORYBOARD_VIDEO("分镜视频");
private final String description;
@@ -206,3 +215,5 @@ public class PointsFreezeRecord {

View File

@@ -47,11 +47,20 @@ public class StoryboardVideoTask {
@Column(nullable = false)
private int progress; // 0-100
@Column(name = "result_url", columnDefinition = "TEXT")
private String resultUrl; // 分镜图URL
@Column(name = "result_url", columnDefinition = "LONGTEXT")
private String resultUrl; // 分镜图URLBase64编码的图片可能非常大- 网格图
@Column(name = "storyboard_images", columnDefinition = "LONGTEXT")
private String storyboardImages; // 单独的分镜图片JSON数组每张图片为Base64格式带data URI前缀
@Column(name = "real_task_id")
private String realTaskId;
private String realTaskId; // 主任务ID已废弃保留用于兼容
@Column(name = "video_task_ids", columnDefinition = "TEXT")
private String videoTaskIds; // 多个视频任务IDJSON数组每张图片对应一个视频任务
@Column(name = "video_urls", columnDefinition = "LONGTEXT")
private String videoUrls; // 多个视频URLJSON数组用于拼接
@Column(columnDefinition = "TEXT")
private String errorMessage;
@@ -141,10 +150,16 @@ public class StoryboardVideoTask {
public void setProgress(int progress) { this.progress = progress; }
public String getResultUrl() { return resultUrl; }
public void setResultUrl(String resultUrl) { this.resultUrl = resultUrl; }
public String getStoryboardImages() { return storyboardImages; }
public void setStoryboardImages(String storyboardImages) { this.storyboardImages = storyboardImages; }
public String getErrorMessage() { return errorMessage; }
public void setErrorMessage(String errorMessage) { this.errorMessage = errorMessage; }
public String getRealTaskId() { return realTaskId; }
public void setRealTaskId(String realTaskId) { this.realTaskId = realTaskId; }
public String getVideoTaskIds() { return videoTaskIds; }
public void setVideoTaskIds(String videoTaskIds) { this.videoTaskIds = videoTaskIds; }
public String getVideoUrls() { return videoUrls; }
public void setVideoUrls(String videoUrls) { this.videoUrls = videoUrls; }
public int getCostPoints() { return costPoints; }
public void setCostPoints(int costPoints) { this.costPoints = costPoints; }
public LocalDateTime getCreatedAt() { return createdAt; }

View File

@@ -1,8 +1,16 @@
package com.example.demo.model;
import jakarta.persistence.*;
import java.time.LocalDateTime;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
/**
* 任务队列实体
* 用于管理用户的视频生成任务队列
@@ -42,7 +50,7 @@ public class TaskQueue {
private Integer checkCount = 0; // 检查次数
@Column(name = "max_check_count", nullable = false)
private Integer maxCheckCount = 30; // 最大检查次数(30次 * 2分钟 = 60分钟
private Integer maxCheckCount = 5; // 最大检查次数(5次 * 2分钟 = 10分钟
@Column(name = "error_message", columnDefinition = "TEXT")
private String errorMessage;
@@ -61,7 +69,8 @@ public class TaskQueue {
*/
public enum TaskType {
TEXT_TO_VIDEO("文生视频"),
IMAGE_TO_VIDEO("图生视频");
IMAGE_TO_VIDEO("图生视频"),
STORYBOARD_VIDEO("分镜视频");
private final String description;
@@ -274,3 +283,5 @@ public class TaskQueue {

View File

@@ -266,3 +266,6 @@ public class TaskStatus {

View File

@@ -65,6 +65,9 @@ public class User {
@Column(name = "address", columnDefinition = "TEXT")
private String address;
@Column(name = "bio", columnDefinition = "TEXT")
private String bio; // 个人简介
@Column(name = "is_active", nullable = false)
private Boolean isActive = true;
@@ -218,6 +221,14 @@ public class User {
this.address = address;
}
public String getBio() {
return bio;
}
public void setBio(String bio) {
this.bio = bio;
}
public Boolean getIsActive() {
return isActive;
}

View File

@@ -99,7 +99,8 @@ public class UserWork {
*/
public enum WorkType {
TEXT_TO_VIDEO("文生视频"),
IMAGE_TO_VIDEO("图生视频");
IMAGE_TO_VIDEO("图生视频"),
STORYBOARD_VIDEO("分镜视频");
private final String description;

View File

@@ -1,6 +1,9 @@
package com.example.demo.repository;
import com.example.demo.model.ImageToVideoTask;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
@@ -9,8 +12,7 @@ import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
import com.example.demo.model.ImageToVideoTask;
/**
* 图生视频任务数据访问层
@@ -85,4 +87,15 @@ public interface ImageToVideoTaskRepository extends JpaRepository<ImageToVideoTa
@Modifying
@Query("DELETE FROM ImageToVideoTask t WHERE t.status = :status")
int deleteByStatus(@Param("status") String status);
/**
* 查找超时的图生视频任务
* 条件状态为PROCESSING且创建时间超过10分钟
*/
@Query("SELECT t FROM ImageToVideoTask t WHERE t.status = :status " +
"AND t.createdAt < :timeoutTime")
List<ImageToVideoTask> findTimeoutTasks(
@Param("status") ImageToVideoTask.TaskStatus status,
@Param("timeoutTime") LocalDateTime timeoutTime
);
}

View File

@@ -1,11 +1,14 @@
package com.example.demo.repository;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import com.example.demo.model.StoryboardVideoTask;
@@ -20,4 +23,16 @@ public interface StoryboardVideoTaskRepository extends JpaRepository<StoryboardV
Page<StoryboardVideoTask> findByUsernameOrderByCreatedAtDesc(String username, Pageable pageable);
List<StoryboardVideoTask> findByStatus(StoryboardVideoTask.TaskStatus status);
/**
* 查找超时的分镜视频任务
* 条件状态为PROCESSINGrealTaskId为空说明还在生成分镜图阶段且创建时间超过10分钟
*/
@Query("SELECT t FROM StoryboardVideoTask t WHERE t.status = :status " +
"AND (t.realTaskId IS NULL OR t.realTaskId = '') " +
"AND t.createdAt < :timeoutTime")
List<StoryboardVideoTask> findTimeoutTasks(
@Param("status") StoryboardVideoTask.TaskStatus status,
@Param("timeoutTime") LocalDateTime timeoutTime
);
}

View File

@@ -7,6 +7,7 @@ import java.util.Optional;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Lock;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
@@ -14,6 +15,8 @@ import org.springframework.stereotype.Repository;
import com.example.demo.model.TaskQueue;
import jakarta.persistence.LockModeType;
/**
* 任务队列仓库接口
*/
@@ -36,6 +39,14 @@ public interface TaskQueueRepository extends JpaRepository<TaskQueue, Long> {
* 根据任务ID查找队列任务
*/
Optional<TaskQueue> findByTaskId(String taskId);
/**
* 使用悲观锁查找任务SELECT FOR UPDATE
* 用于防止并发处理同一任务
*/
@Query("SELECT tq FROM TaskQueue tq WHERE tq.taskId = :taskId")
@Lock(LockModeType.PESSIMISTIC_WRITE)
Optional<TaskQueue> findByTaskIdWithLock(@Param("taskId") String taskId);
/**
* 根据用户名和任务ID查找队列任务

View File

@@ -74,3 +74,6 @@ public interface TaskStatusRepository extends JpaRepository<TaskStatus, Long> {

View File

@@ -1,6 +1,9 @@
package com.example.demo.repository;
import com.example.demo.model.TextToVideoTask;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
@@ -9,8 +12,7 @@ import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
import com.example.demo.model.TextToVideoTask;
/**
* 文生视频任务Repository
@@ -90,5 +92,15 @@ public interface TextToVideoTaskRepository extends JpaRepository<TextToVideoTask
@Modifying
@Query("DELETE FROM TextToVideoTask t WHERE t.status = :status")
int deleteByStatus(@Param("status") String status);
/**
* 查找超时的文生视频任务
* 条件状态为PROCESSING且创建时间超过10分钟
*/
@Query("SELECT t FROM TextToVideoTask t WHERE t.status = :status " +
"AND t.createdAt < :timeoutTime")
List<TextToVideoTask> findTimeoutTasks(
@Param("status") TextToVideoTask.TaskStatus status,
@Param("timeoutTime") LocalDateTime timeoutTime
);
}

View File

@@ -1,14 +1,18 @@
package com.example.demo.repository;
import com.example.demo.model.UserMembership;
import java.time.LocalDateTime;
import java.util.Optional;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
import java.util.Optional;
import com.example.demo.model.UserMembership;
@Repository
public interface UserMembershipRepository extends JpaRepository<UserMembership, Long> {
Optional<UserMembership> findByUserIdAndStatus(Long userId, String status);
long countByStatus(String status);
long countByStartDateBetween(LocalDateTime startDate, LocalDateTime endDate);
}

View File

@@ -1,5 +1,6 @@
package com.example.demo.repository;
import java.time.LocalDateTime;
import java.util.Optional;
import org.springframework.data.jpa.repository.JpaRepository;
@@ -13,6 +14,7 @@ public interface UserRepository extends JpaRepository<User, Long> {
boolean existsByUsername(String username);
boolean existsByEmail(String email);
boolean existsByPhone(String phone);
long countByCreatedAtBetween(LocalDateTime startDate, LocalDateTime endDate);
}

View File

@@ -8,8 +8,11 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import com.example.demo.service.ImageToVideoService;
import com.example.demo.service.StoryboardVideoService;
import com.example.demo.service.TaskCleanupService;
import com.example.demo.service.TaskQueueService;
import com.example.demo.service.TextToVideoService;
/**
* 任务队列定时调度器
@@ -25,6 +28,15 @@ public class TaskQueueScheduler {
@Autowired
private TaskCleanupService taskCleanupService;
@Autowired
private StoryboardVideoService storyboardVideoService;
@Autowired
private TextToVideoService textToVideoService;
@Autowired
private ImageToVideoService imageToVideoService;
/**
* 处理待处理任务
@@ -48,13 +60,62 @@ public class TaskQueueScheduler {
@Scheduled(fixedRate = 120000) // 2分钟 = 120000毫秒
public void checkTaskStatuses() {
try {
logger.info("=== 开始执行任务队列状态轮询查询 (每2分钟) ===");
taskQueueService.checkTaskStatuses();
logger.info("=== 任务队列状态轮询查询完成 ===");
} catch (Exception e) {
logger.error("检查任务状态失败", e);
}
}
/**
* 检查分镜图生成超时任务
* 每2分钟执行一次检查长时间处于PROCESSING状态但没有realTaskId的分镜视频任务
* 如果创建时间超过10分钟则标记为超时失败
*/
@Scheduled(fixedRate = 120000) // 2分钟 = 120000毫秒
public void checkStoryboardImageGenerationTimeout() {
try {
int handledCount = storyboardVideoService.checkAndHandleTimeoutTasks();
if (handledCount > 0) {
logger.warn("处理了 {} 个超时的分镜图生成任务", handledCount);
}
} catch (Exception e) {
logger.error("检查分镜图生成超时任务失败", e);
}
}
/**
* 检查文生视频超时任务
* 每2分钟执行一次检查长时间处于PROCESSING状态的文生视频任务
* 如果创建时间超过10分钟则标记为超时失败
*/
@Scheduled(fixedRate = 120000) // 2分钟 = 120000毫秒
public void checkTextToVideoTimeout() {
try {
int handledCount = textToVideoService.checkAndHandleTimeoutTasks();
if (handledCount > 0) {
logger.warn("处理了 {} 个超时的文生视频任务", handledCount);
}
} catch (Exception e) {
logger.error("检查文生视频超时任务失败", e);
}
}
/**
* 检查图生视频超时任务
* 每2分钟执行一次检查长时间处于PROCESSING状态的图生视频任务
* 如果创建时间超过10分钟则标记为超时失败
*/
@Scheduled(fixedRate = 120000) // 2分钟 = 120000毫秒
public void checkImageToVideoTimeout() {
try {
int handledCount = imageToVideoService.checkAndHandleTimeoutTasks();
if (handledCount > 0) {
logger.warn("处理了 {} 个超时的图生视频任务", handledCount);
}
} catch (Exception e) {
logger.error("检查图生视频超时任务失败", e);
}
}
/**
* 清理过期任务

View File

@@ -64,8 +64,9 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
List<GrantedAuthority> authorities = new ArrayList<>();
authorities.add(new SimpleGrantedAuthority(user.getRole()));
// 将User对象作为Principal而不是用户名字符串
UsernamePasswordAuthenticationToken authToken =
new UsernamePasswordAuthenticationToken(user.getUsername(), null, authorities);
new UsernamePasswordAuthenticationToken(user, null, authorities);
authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
SecurityContextHolder.getContext().setAuthentication(authToken);
logger.debug("JWT认证成功用户: {}, 角色: {}", username, user.getRole());

View File

@@ -41,6 +41,9 @@ public class PlainTextPasswordEncoder implements PasswordEncoder {

View File

@@ -1,6 +1,7 @@
package com.example.demo.service;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -101,15 +102,31 @@ public class ImageGridService {
g.dispose();
// 转换为Base64
// 压缩网格图片以减小体积(限制最大尺寸)
BufferedImage compressedGrid = compressGridImage(gridImage, 2048); // 网格图最大2048px
// 转换为JPEG格式的Base64JPEG压缩率更高体积更小
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ImageIO.write(gridImage, "PNG", baos);
javax.imageio.ImageWriter writer = javax.imageio.ImageIO.getImageWritersByFormatName("jpg").next();
javax.imageio.ImageWriteParam param = writer.getDefaultWriteParam();
if (param.canWriteCompressed()) {
param.setCompressionMode(javax.imageio.ImageWriteParam.MODE_EXPLICIT);
param.setCompressionQuality(0.85f); // JPEG质量85%
}
javax.imageio.IIOImage iioImage = new javax.imageio.IIOImage(compressedGrid, null, null);
writer.setOutput(javax.imageio.ImageIO.createImageOutputStream(baos));
writer.write(null, iioImage, param);
writer.dispose();
byte[] imageBytes = baos.toByteArray();
String base64 = Base64.getEncoder().encodeToString(imageBytes);
logger.info("图片网格拼接完成: 总尺寸={}x{}", gridImage.getWidth(), gridImage.getHeight());
logger.info("图片网格拼接完成: 原始尺寸={}x{}, 压缩后尺寸={}x{}, 大小={} KB",
gridImage.getWidth(), gridImage.getHeight(),
compressedGrid.getWidth(), compressedGrid.getHeight(),
imageBytes.length / 1024);
return "data:image/png;base64," + base64;
return "data:image/jpeg;base64," + base64;
} catch (Exception e) {
logger.error("拼接图片网格失败", e);
@@ -130,6 +147,44 @@ public class ImageGridService {
return 4; // 默认4列
}
/**
* 压缩网格图片以减小体积
* @param originalImage 原始图片
* @param maxSize 最大尺寸(宽度或高度)
* @return 压缩后的图片
*/
private BufferedImage compressGridImage(BufferedImage originalImage, int maxSize) {
int originalWidth = originalImage.getWidth();
int originalHeight = originalImage.getHeight();
// 如果图片尺寸小于等于最大尺寸,直接返回
if (originalWidth <= maxSize && originalHeight <= maxSize) {
return originalImage;
}
// 计算缩放比例
double scale = Math.min((double) maxSize / originalWidth, (double) maxSize / originalHeight);
int newWidth = (int) (originalWidth * scale);
int newHeight = (int) (originalHeight * scale);
logger.debug("压缩网格图片: {}x{} -> {}x{} (缩放比例: {})",
originalWidth, originalHeight, newWidth, newHeight, scale);
// 创建缩放后的图片
BufferedImage compressedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = compressedImage.createGraphics();
// 设置高质量缩放
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g.setRenderingHint(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
g.drawImage(originalImage, 0, 0, newWidth, newHeight, null);
g.dispose();
return compressedImage;
}
/**
* 从URL加载图片
*/

View File

@@ -27,9 +27,9 @@ import com.example.demo.repository.ImageToVideoTaskRepository;
/**
* 图生视频服务类
* 注意:不在类级别使用 @Transactional因为某些方法需要禁用事务如长时间运行的外部API调用
*/
@Service
@Transactional
public class ImageToVideoService {
private static final Logger logger = LoggerFactory.getLogger(ImageToVideoService.class);
@@ -52,6 +52,7 @@ public class ImageToVideoService {
/**
* 创建图生视频任务
*/
@Transactional
public ImageToVideoTask createTask(String username, MultipartFile firstFrame,
MultipartFile lastFrame, String prompt,
String aspectRatio, int duration, boolean hdMode) {
@@ -136,43 +137,36 @@ public class ImageToVideoService {
return taskRepository.findByTaskId(taskId).orElse(null);
}
/**
* 取消任务
*/
@Transactional
public boolean cancelTask(String taskId, String username) {
// 使用悲观锁避免并发问题
ImageToVideoTask task = taskRepository.findByTaskId(taskId).orElse(null);
if (task == null || task.getUsername() == null || !task.getUsername().equals(username)) {
return false;
}
// 检查任务状态只有PENDING和PROCESSING状态的任务才能取消
if (task.getStatus() == ImageToVideoTask.TaskStatus.PENDING ||
task.getStatus() == ImageToVideoTask.TaskStatus.PROCESSING) {
task.updateStatus(ImageToVideoTask.TaskStatus.CANCELLED);
task.setErrorMessage("用户取消了任务");
taskRepository.save(task);
logger.info("图生视频任务已取消: taskId={}, username={}", taskId, username);
return true;
}
return false;
}
/**
* 使用真实API处理任务
*/
@Async
@Async("taskExecutor")
public CompletableFuture<Void> processTaskWithRealAPI(ImageToVideoTask task, MultipartFile firstFrame) {
try {
logger.info("开始使用真实API处理图生视频任务: {}", task.getTaskId());
// 重新从数据库加载任务,确保获取最新状态
ImageToVideoTask currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
// 检查任务是否已经有 realTaskId如果有说明已经提交过了不应该再次处理
if (currentTask.getRealTaskId() != null && !currentTask.getRealTaskId().isEmpty()) {
logger.warn("图生视频任务 {} 已经有 realTaskId{}),说明已经提交过了,跳过处理",
task.getTaskId(), currentTask.getRealTaskId());
return CompletableFuture.completedFuture(null);
}
// 检查任务状态如果已经不是PENDING说明已经被其他线程处理了
if (currentTask.getStatus() != ImageToVideoTask.TaskStatus.PENDING) {
logger.warn("图生视频任务 {} 状态已不是PENDING当前状态: {}),跳过处理,可能已被其他线程处理",
task.getTaskId(), currentTask.getStatus());
return CompletableFuture.completedFuture(null);
}
// 更新任务状态为处理中
task.updateStatus(ImageToVideoTask.TaskStatus.PROCESSING);
taskRepository.save(task);
currentTask.updateStatus(ImageToVideoTask.TaskStatus.PROCESSING);
taskRepository.save(currentTask);
// 将图片转换为Base64
String imageBase64 = realAIService.convertImageToBase64(
@@ -182,11 +176,11 @@ public class ImageToVideoService {
// 调用真实API提交任务
Map<String, Object> apiResponse = realAIService.submitImageToVideoTask(
task.getPrompt(),
currentTask.getPrompt(),
imageBase64,
task.getAspectRatio(),
task.getDuration().toString(),
task.getHdMode()
currentTask.getAspectRatio(),
currentTask.getDuration().toString(),
currentTask.getHdMode()
);
// 从API响应中提取真实任务ID
@@ -223,20 +217,25 @@ public class ImageToVideoService {
// 如果找到了真实任务ID保存到数据库
if (realTaskId != null) {
task.setRealTaskId(realTaskId);
taskRepository.save(task);
// 重新加载任务以确保获取最新状态
currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
currentTask.setRealTaskId(realTaskId);
taskRepository.save(currentTask);
logger.info("真实任务ID已保存: {} -> {}", task.getTaskId(), realTaskId);
} else {
// 如果没有找到任务ID说明任务提交失败
logger.error("任务提交失败未从API响应中获取到任务ID");
task.updateStatus(ImageToVideoTask.TaskStatus.FAILED);
task.setErrorMessage("任务提交失败API未返回有效的任务ID");
taskRepository.save(task);
currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
currentTask.updateStatus(ImageToVideoTask.TaskStatus.FAILED);
currentTask.setErrorMessage("任务提交失败API未返回有效的任务ID");
taskRepository.save(currentTask);
return CompletableFuture.completedFuture(null); // 直接返回,不进行轮询
}
// 开始轮询真实任务状态
pollRealTaskStatus(task);
pollRealTaskStatus(currentTask);
} catch (Exception e) {
logger.error("使用真实API处理图生视频任务失败: {}", task.getTaskId(), e);
@@ -463,4 +462,64 @@ public class ImageToVideoService {
LocalDateTime expiredDate = LocalDateTime.now().minusDays(30);
return taskRepository.deleteExpiredTasks(expiredDate);
}
/**
* 检查并处理超时的图生视频任务
* 如果任务状态为PROCESSING且创建时间超过10分钟则标记为超时
* 注意如果任务已经有resultUrl视频已生成即使超时也不标记为失败因为视频已经成功生成
*/
@Transactional
public int checkAndHandleTimeoutTasks() {
try {
// 计算超时时间点10分钟前
LocalDateTime timeoutTime = LocalDateTime.now().minusMinutes(10);
// 查找超时的任务状态为PROCESSING创建时间超过10分钟
List<ImageToVideoTask> timeoutTasks = taskRepository.findTimeoutTasks(
ImageToVideoTask.TaskStatus.PROCESSING,
timeoutTime
);
if (timeoutTasks.isEmpty()) {
return 0;
}
logger.warn("发现 {} 个可能超时的图生视频任务,开始检查", timeoutTasks.size());
int handledCount = 0;
int skippedCount = 0;
for (ImageToVideoTask task : timeoutTasks) {
try {
// 检查任务是否已经有resultUrl视频已生成
// 如果有resultUrl说明视频已经成功生成不应该被标记为超时失败
if (task.getResultUrl() != null && !task.getResultUrl().isEmpty()) {
logger.debug("任务 {} 已有resultUrl视频已生成跳过超时标记", task.getTaskId());
skippedCount++;
continue;
}
// 更新任务状态为失败
task.updateStatus(ImageToVideoTask.TaskStatus.FAILED);
task.setErrorMessage("图生视频任务超时任务创建后超过10分钟仍未完成");
taskRepository.save(task);
logger.warn("图生视频任务超时,已标记为失败: taskId={}", task.getTaskId());
handledCount++;
} catch (Exception e) {
logger.error("处理超时图生视频任务失败: taskId={}", task.getTaskId(), e);
}
}
if (handledCount > 0 || skippedCount > 0) {
logger.info("处理超时图生视频任务完成,失败: {}/{},跳过(已生成): {}",
handledCount, timeoutTasks.size(), skippedCount);
}
return handledCount;
} catch (Exception e) {
logger.error("检查超时图生视频任务失败", e);
return 0;
}
}
}

View File

@@ -6,7 +6,6 @@ import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import com.example.demo.model.TaskQueue;
@@ -14,7 +13,8 @@ import com.example.demo.repository.TaskQueueRepository;
/**
* 轮询查询服务
* 每2分钟执行一次查询任务状态
* 注意自动调度已禁用避免与TaskQueueScheduler重复查询
* 保留此服务用于手动调用或测试
*/
@Service
public class PollingQueryService {
@@ -28,11 +28,11 @@ public class PollingQueryService {
private TaskQueueRepository taskQueueRepository;
/**
* 每2分钟执行一次轮询查询
* 固定间隔120000毫秒 = 2分钟
* 查询所有正在处理的任务状态
* 执行轮询查询已禁用自动调度避免与TaskQueueScheduler重复
* 保留此方法用于手动调用或测试
* 注意TaskQueueScheduler.checkTaskStatuses() 已经统一管理任务状态检查
*/
@Scheduled(fixedRate = 120000) // 2分钟 = 120000毫秒
// @Scheduled(fixedRate = 120000) // 已禁用避免与TaskQueueScheduler重复查询
public void executePollingQuery() {
logger.info("=== 开始执行轮询查询 (每2分钟) ===");
logger.info("轮询查询时间: {}", LocalDateTime.now());

View File

@@ -42,98 +42,364 @@ public class RealAIService {
public RealAIService() {
this.objectMapper = new ObjectMapper();
// 设置Unirest超时
Unirest.config().connectTimeout(0).socketTimeout(0);
// 配置ObjectMapper保留null值API可能期望null而不是省略字段
// objectMapper.setSerializationInclusion(com.fasterxml.jackson.annotation.JsonInclude.Include.ALWAYS);
// 默认行为是包含null值这是正确的
// 设置Unirest超时 - 参考Comfly实现连接超时30秒读取超时5分钟300秒匹配Python requests timeout=300
// 禁用Apache HttpClient的内部重试使用我们自己的重试机制
Unirest.config()
.connectTimeout(30000) // 30秒连接超时
.socketTimeout(300000) // 5分钟读取超时300秒匹配参考代码
.retryAfter(false) // 禁用自动重试,使用我们自己的重试逻辑
.httpClient(org.apache.http.impl.client.HttpClients.custom()
.setRetryHandler(new org.apache.http.impl.client.DefaultHttpRequestRetryHandler(0, false)) // 禁用内部重试
.setMaxConnTotal(500)
.setMaxConnPerRoute(100)
.setConnectionTimeToLive(30, java.util.concurrent.TimeUnit.SECONDS)
.evictExpiredConnections()
.evictIdleConnections(30, java.util.concurrent.TimeUnit.SECONDS)
// 配置请求和响应缓冲区
.setDefaultRequestConfig(org.apache.http.client.config.RequestConfig.custom()
.setConnectTimeout(30000)
.setSocketTimeout(300000)
.setConnectionRequestTimeout(30000)
.setContentCompressionEnabled(false) // 禁用压缩,避免额外开销
.build())
.build());
}
/**
* 提交图生视频任务
* 提交图生视频任务多张图片参考sora2实现
* 参考Comfly.py 6285-6292行使用images数组
*/
public Map<String, Object> submitStoryboardVideoTask(String prompt, List<String> images,
String aspectRatio, String duration,
boolean hdMode) {
int maxRetries = 3;
int retryCount = 0;
long baseDelayMs = 5000; // 基础延迟5秒
while (retryCount < maxRetries) {
try {
// 根据参数选择可用的模型使用sora2模型
String modelName = selectTextToVideoModel(aspectRatio, duration, hdMode);
// 验证图片格式参考sora2实现确保每张图片都有data URI前缀
List<String> validatedImages = validateImageFormat(images);
// 使用 Sora2 端点参考Comfly.py 6297行
String url = aiApiBaseUrl + "/v2/videos/generations";
// 使用 Sora2 API 的请求格式参考Comfly.py 6285-6292行
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("prompt", prompt);
requestMap.put("model", modelName);
requestMap.put("images", validatedImages); // 使用images数组参考sora2实现
requestMap.put("aspect_ratio", aspectRatio);
requestMap.put("duration", duration);
requestMap.put("hd", hdMode);
String requestBody;
try {
requestBody = objectMapper.writeValueAsString(requestMap);
} catch (Exception e) {
logger.error("构建JSON请求体失败", e);
throw new RuntimeException("构建请求体失败: " + e.getMessage(), e);
}
// 记录请求体大小
long requestBodySize = requestBody.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
long requestBodySizeMB = requestBodySize / (1024 * 1024);
long requestBodySizeKB = requestBodySize / 1024;
if (retryCount > 0) {
logger.info("分镜视频请求重试 (第{}次,共{}次): URL={}, 请求体大小={}KB ({}MB), model={}, 图片数量={}",
retryCount + 1, maxRetries, url, requestBodySizeKB, requestBodySizeMB, modelName, validatedImages.size());
} else {
logger.info("分镜视频请求: URL={}, 请求体大小={}KB ({}MB, {}字节), model={}, aspectRatio={}, duration={}, 图片数量={}",
url, requestBodySizeKB, requestBodySizeMB, requestBodySize, modelName, aspectRatio, duration, validatedImages.size());
}
HttpResponse<String> response = Unirest.post(url)
.header("Authorization", "Bearer " + aiApiKey)
.header("Content-Type", "application/json; charset=UTF-8")
.header("Accept", "application/json")
.header("Connection", "keep-alive")
.body(requestBody)
.asString();
logger.info("API响应状态: {}", response.getStatus());
String responseBodyStr = response.getBody();
logger.info("API响应内容前500字符: {}", responseBodyStr != null && responseBodyStr.length() > 500 ?
responseBodyStr.substring(0, 500) : responseBodyStr);
if (response.getStatus() == 200 && responseBodyStr != null) {
String trimmedResponse = responseBodyStr.trim();
String lowerResponse = trimmedResponse.toLowerCase();
if (lowerResponse.startsWith("<!") || lowerResponse.startsWith("<html") ||
lowerResponse.contains("<!doctype") || (!trimmedResponse.startsWith("{") && !trimmedResponse.startsWith("["))) {
logger.error("API返回HTML页面而不是JSON可能是认证失败或API端点错误");
throw new RuntimeException("API返回HTML页面可能是认证失败。请检查API密钥和端点配置");
}
try {
@SuppressWarnings("unchecked")
Map<String, Object> responseBody = objectMapper.readValue(responseBodyStr, Map.class);
// Sora2 API 使用 task_id 字段表示成功
if (responseBody.containsKey("task_id")) {
logger.info("分镜视频任务提交成功task_id: {}", responseBody.get("task_id"));
Map<String, Object> result = new HashMap<>();
result.put("code", 200);
result.put("data", responseBody);
result.put("task_id", responseBody.get("task_id"));
return result;
} else {
// 处理错误响应
String errorMsg = responseBody.containsKey("message") ?
responseBody.get("message").toString() : "未知错误";
logger.error("分镜视频任务提交失败: {}", errorMsg);
throw new RuntimeException("API返回错误: " + errorMsg);
}
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
logger.error("解析API响应JSON失败: {}", responseBodyStr, e);
throw new RuntimeException("解析API响应失败: " + e.getMessage());
}
} else {
String errorMsg = String.format("API请求失败HTTP状态: %d, 响应: %s",
response.getStatus(), responseBodyStr);
logger.error(errorMsg);
throw new RuntimeException(errorMsg);
}
} catch (Exception e) {
retryCount++;
if (retryCount >= maxRetries) {
logger.error("分镜视频任务提交失败,已重试{}次: {}", maxRetries, e.getMessage(), e);
throw new RuntimeException("分镜视频任务提交失败: " + e.getMessage(), e);
}
long delayMs = baseDelayMs * retryCount;
logger.warn("分镜视频任务提交失败,{}秒后重试 (第{}次,共{}次): {}",
delayMs / 1000, retryCount, maxRetries, e.getMessage());
try {
Thread.sleep(delayMs);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("重试等待被中断", ie);
}
}
}
throw new RuntimeException("分镜视频任务提交失败,已重试" + maxRetries + "");
}
/**
* 提交图生视频任务(带重试机制)
*/
public Map<String, Object> submitImageToVideoTask(String prompt, String imageBase64,
String aspectRatio, String duration,
boolean hdMode) {
try {
// 根据参数选择可用的模型
String modelName = selectAvailableImageToVideoModel(aspectRatio, duration, hdMode);
// 将Base64图片转换为字节数组
String base64Data = imageBase64;
if (imageBase64.contains(",")) {
base64Data = imageBase64.substring(imageBase64.indexOf(",") + 1);
}
// 验证base64数据格式
int maxRetries = 3;
int retryCount = 0;
long baseDelayMs = 5000; // 基础延迟5秒给服务器更多恢复时间
int retryAttempt = 0; // 重试次数计数器
while (retryCount < maxRetries) {
try {
Base64.getDecoder().decode(base64Data);
logger.debug("Base64数据格式验证通过");
} catch (IllegalArgumentException e) {
logger.error("Base64数据格式错误: {}", e.getMessage());
throw new RuntimeException("图片数据格式错误");
}
// 根据分辨率选择size参数用于日志记录
String size = convertAspectRatioToSize(aspectRatio, hdMode);
logger.debug("选择的尺寸参数: {}", size);
String url = aiApiBaseUrl + "/user/ai/tasks/submit";
String requestBody = String.format("{\"modelName\":\"%s\",\"prompt\":\"%s\",\"aspectRatio\":\"%s\",\"imageToVideo\":true,\"imageBase64\":\"%s\"}",
modelName, prompt, aspectRatio, imageBase64);
logger.info("图生视频请求体: {}", requestBody);
HttpResponse<String> response = Unirest.post(url)
.header("Authorization", "Bearer " + aiApiKey)
.header("Content-Type", "application/json")
.body(requestBody)
.asString();
// 添加响应调试日志
logger.info("API响应状态: {}", response.getStatus());
String responseBodyStr = response.getBody();
logger.info("API响应内容前500字符: {}", responseBodyStr != null && responseBodyStr.length() > 500 ?
responseBodyStr.substring(0, 500) : responseBodyStr);
if (response.getStatus() == 200 && responseBodyStr != null) {
// 检查响应是否为HTML可能是认证失败或API端点错误
String trimmedResponse = responseBodyStr.trim();
String lowerResponse = trimmedResponse.toLowerCase();
if (lowerResponse.startsWith("<!") || lowerResponse.startsWith("<html") ||
lowerResponse.contains("<!doctype") || (!trimmedResponse.startsWith("{") && !trimmedResponse.startsWith("["))) {
logger.error("API返回HTML页面而不是JSON可能是认证失败或API端点错误");
logger.error("响应前100字符: {}", trimmedResponse.length() > 100 ? trimmedResponse.substring(0, 100) : trimmedResponse);
logger.error("请检查1) API密钥是否正确 2) API端点URL是否正确 3) API服务是否正常运行");
throw new RuntimeException("API返回HTML页面可能是认证失败。请检查API密钥和端点配置");
// 根据参数选择可用的模型
String modelName = selectAvailableImageToVideoModel(aspectRatio, duration, hdMode);
// 验证base64数据格式提取纯Base64数据用于验证
String base64DataForValidation = imageBase64;
if (imageBase64.contains(",")) {
base64DataForValidation = imageBase64.substring(imageBase64.indexOf(",") + 1);
}
try {
Base64.getDecoder().decode(base64DataForValidation);
logger.debug("Base64数据格式验证通过");
} catch (IllegalArgumentException e) {
logger.error("Base64数据格式错误: {}", e.getMessage());
throw new RuntimeException("图片数据格式错误");
}
// 根据分辨率选择size参数用于日志记录
String size = convertAspectRatioToSize(aspectRatio, hdMode);
logger.debug("选择的尺寸参数: {}", size);
// 使用 Sora2 端点(与文生视频使用相同的端点,参考 Comfly.py 6297 行)
String url = aiApiBaseUrl + "/v2/videos/generations";
// 使用 Sora2 API 的请求格式(参考 Comfly.py 6285-6292行
// 图生视频使用 images 数组(即使只有一张图片)
// 验证并规范化图片格式参考sora2实现
List<String> imagesList = new java.util.ArrayList<>();
imagesList.add(imageBase64);
imagesList = validateImageFormat(imagesList);
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("prompt", prompt);
requestMap.put("model", modelName);
requestMap.put("images", imagesList); // 使用 images 数组,不是单个 image
requestMap.put("aspect_ratio", aspectRatio);
requestMap.put("duration", duration);
requestMap.put("hd", hdMode);
String requestBody;
try {
@SuppressWarnings("unchecked")
Map<String, Object> responseBody = objectMapper.readValue(responseBodyStr, Map.class);
Integer code = (Integer) responseBody.get("code");
if (code != null && code == 200) {
logger.info("图生视频任务提交成功: {}", responseBody);
return responseBody;
} else {
logger.error("图生视频任务提交失败: {}", responseBody);
throw new RuntimeException("任务提交失败: " + responseBody.get("message"));
}
} catch (com.fasterxml.jackson.core.JsonParseException e) {
logger.error("解析API响应为JSON失败响应内容可能是HTML或其他格式", e);
logger.error("响应内容前200字符: {}", responseBodyStr.length() > 200 ?
responseBodyStr.substring(0, 200) : responseBodyStr);
throw new RuntimeException("API返回非JSON响应可能是认证失败。请检查API密钥和端点配置");
requestBody = objectMapper.writeValueAsString(requestMap);
} catch (Exception e) {
logger.error("构建JSON请求体失败", e);
throw new RuntimeException("构建请求体失败: " + e.getMessage(), e);
}
} else {
logger.error("图生视频任务提交失败HTTP状态: {}", response.getStatus());
throw new RuntimeException("任务提交失败HTTP状态: " + response.getStatus());
}
// 记录请求体大小Base64编码后的图片可能很大
long requestBodySize = requestBody.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
long requestBodySizeMB = requestBodySize / (1024 * 1024);
long requestBodySizeKB = requestBodySize / 1024;
if (retryCount > 0) {
logger.info("图生视频请求重试 (第{}次,共{}次): URL={}, 请求体大小={}KB ({}MB), model={}",
retryCount + 1, maxRetries, url, requestBodySizeKB, requestBodySizeMB, modelName);
} else {
logger.info("图生视频请求: URL={}, 请求体大小={}KB ({}MB, {}字节), model={}, aspectRatio={}, duration={}",
url, requestBodySizeKB, requestBodySizeMB, requestBodySize, modelName, aspectRatio, duration);
// 如果请求体太大只记录前500字符
if (requestBody.length() > 500) {
logger.debug("请求体前500字符: {}", requestBody.substring(0, 500));
}
}
// 使用流式传输,避免一次性加载整个请求体到内存
// 添加额外的请求头以支持大请求体
HttpResponse<String> response = Unirest.post(url)
.header("Authorization", "Bearer " + aiApiKey)
.header("Content-Type", "application/json; charset=UTF-8")
.header("Accept", "application/json")
.header("Connection", "keep-alive")
.body(requestBody)
.asString();
} catch (UnirestException e) {
logger.error("提交图生视频任务异常", e);
throw new RuntimeException("提交任务失败: " + e.getMessage());
} catch (Exception e) {
logger.error("提交图生视频任务异常", e);
throw new RuntimeException("提交任务失败: " + e.getMessage());
// 添加响应调试日志
logger.info("API响应状态: {}", response.getStatus());
String responseBodyStr = response.getBody();
logger.info("API响应内容前500字符: {}", responseBodyStr != null && responseBodyStr.length() > 500 ?
responseBodyStr.substring(0, 500) : responseBodyStr);
if (response.getStatus() == 200 && responseBodyStr != null) {
// 检查响应是否为HTML可能是认证失败或API端点错误
String trimmedResponse = responseBodyStr.trim();
String lowerResponse = trimmedResponse.toLowerCase();
if (lowerResponse.startsWith("<!") || lowerResponse.startsWith("<html") ||
lowerResponse.contains("<!doctype") || (!trimmedResponse.startsWith("{") && !trimmedResponse.startsWith("["))) {
logger.error("API返回HTML页面而不是JSON可能是认证失败或API端点错误");
logger.error("响应前100字符: {}", trimmedResponse.length() > 100 ? trimmedResponse.substring(0, 100) : trimmedResponse);
logger.error("请检查1) API密钥是否正确 2) API端点URL是否正确 3) API服务是否正常运行");
throw new RuntimeException("API返回HTML页面可能是认证失败。请检查API密钥和端点配置");
}
try {
@SuppressWarnings("unchecked")
Map<String, Object> responseBody = objectMapper.readValue(responseBodyStr, Map.class);
// Sora2 API 使用 task_id 字段表示成功(与文生视频相同格式)
if (responseBody.containsKey("task_id")) {
logger.info("图生视频任务提交成功task_id: {}", responseBody.get("task_id"));
// 转换为统一的响应格式(与文生视频保持一致)
Map<String, Object> result = new HashMap<>();
result.put("code", 200);
result.put("data", responseBody);
result.put("task_id", responseBody.get("task_id"));
return result;
} else {
// 处理错误响应
logger.error("图生视频任务提交失败响应中缺少task_id: {}", responseBody);
String errorMsg = "未知错误";
if (responseBody.get("message") != null) {
errorMsg = responseBody.get("message").toString();
}
throw new RuntimeException("任务提交失败: " + errorMsg);
}
} catch (com.fasterxml.jackson.core.JsonParseException e) {
logger.error("解析API响应为JSON失败响应内容可能是HTML或其他格式", e);
logger.error("响应内容前200字符: {}", responseBodyStr.length() > 200 ?
responseBodyStr.substring(0, 200) : responseBodyStr);
throw new RuntimeException("API返回非JSON响应可能是认证失败。请检查API密钥和端点配置");
}
} else {
logger.error("图生视频任务提交失败HTTP状态: {}", response.getStatus());
throw new RuntimeException("任务提交失败HTTP状态: " + response.getStatus());
}
} catch (UnirestException e) {
retryCount++;
retryAttempt++;
Throwable cause = e.getCause();
String errorMessage = e.getMessage();
// 详细记录错误信息用于诊断
logger.warn("UnirestException详情: message={}, cause={}, causeClass={}",
errorMessage,
cause != null ? cause.getMessage() : "null",
cause != null ? cause.getClass().getName() : "null");
// 判断是否为可重试的错误
boolean isRetryable = false;
if (cause != null) {
String causeMessage = cause.getMessage();
// Connection reset, connection refused, timeout 等可重试错误
if (cause instanceof java.net.SocketException) {
isRetryable = true;
logger.warn("网络连接错误 (可重试): {}", causeMessage);
// 如果是 Connection reset可能是服务器端问题或请求体过大
if (causeMessage != null && causeMessage.contains("Connection reset")) {
logger.warn("Connection reset 可能原因: 1) 服务器端限制请求体大小 2) 服务器端超时 3) 网络不稳定");
}
} else if (cause instanceof java.net.ConnectException) {
isRetryable = true;
logger.warn("连接被拒绝 (可重试): {}", causeMessage);
} else if (cause instanceof java.net.SocketTimeoutException) {
isRetryable = true;
logger.warn("连接超时 (可重试): {}", causeMessage);
} else if (errorMessage != null && (
errorMessage.contains("Connection reset") ||
errorMessage.contains("Connection refused") ||
errorMessage.contains("timeout") ||
errorMessage.contains("SocketException"))) {
isRetryable = true;
logger.warn("网络错误 (可重试): {}", errorMessage);
}
}
if (isRetryable && retryCount < maxRetries) {
// 指数退避第1次重试等待5秒第2次等待10秒第3次等待20秒
long delayMs = baseDelayMs * (1L << (retryAttempt - 1));
logger.warn("提交图生视频任务失败 (第{}次尝试,共{}次){} 秒后重试... 错误: {}",
retryCount, maxRetries, delayMs / 1000, errorMessage);
logger.warn("提示: 如果请求体较大,可能是网络传输时被服务器重置连接,重试时将等待更长时间");
try {
Thread.sleep(delayMs);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("重试等待被中断", ie);
}
continue; // 重试
} else {
logger.error("提交图生视频任务异常 (已重试{}次): {}", retryCount, errorMessage, e);
logger.error("建议: 1) 检查网络连接 2) 检查请求体大小是否过大 3) 联系API服务提供商检查服务器状态");
throw new RuntimeException("提交任务失败: " + errorMessage, e);
}
} catch (Exception e) {
// 其他不可重试的错误直接抛出
logger.error("提交图生视频任务异常", e);
throw new RuntimeException("提交任务失败: " + e.getMessage(), e);
}
}
// 理论上不会到达这里,但为了编译通过
throw new RuntimeException("提交图生视频任务失败,已重试" + maxRetries + "");
}
/**
@@ -142,6 +408,7 @@ public class RealAIService {
public Map<String, Object> submitTextToVideoTask(String prompt, String aspectRatio,
String duration, boolean hdMode) {
try {
// 根据参考代码文生视频支持5秒和10秒duration直接传入不转换
// 根据参数选择可用的模型
String modelName = selectAvailableTextToVideoModel(aspectRatio, duration, hdMode);
@@ -159,7 +426,7 @@ public class RealAIService {
requestBodyMap.put("prompt", prompt);
requestBodyMap.put("model", modelName);
requestBodyMap.put("aspect_ratio", aspectRatio);
requestBodyMap.put("duration", duration);
requestBodyMap.put("duration", duration); // duration直接传入不转换
requestBodyMap.put("hd", hdMode);
String requestBody = objectMapper.writeValueAsString(requestBodyMap);
@@ -322,15 +589,17 @@ public class RealAIService {
/**
* 根据参数选择图生视频模型(默认逻辑)
* 使用 Sora2 模型,与文生视频使用相同的模型选择逻辑
*/
private String selectImageToVideoModel(String aspectRatio, String duration, boolean hdMode) {
String size = hdMode ? "large" : "small";
String orientation = "9:16".equals(aspectRatio) || "3:4".equals(aspectRatio) ? "portrait" : "landscape";
// 根据API返回的模型列表只支持10s和15s
String actualDuration = "5".equals(duration) ? "10" : duration;
return String.format("sc_sora2_img_%s_%ss_%s", orientation, actualDuration, size);
// 使用 Sora2 模型,与文生视频相同
// - sora-2: 支持10s和15s不支持25s和HD
// - sora-2-pro: 支持10s、15s和25s支持HD
if ("25".equals(duration) || hdMode) {
return "sora-2-pro";
}
// aspectRatio参数未使用但保留以保持方法签名一致
return "sora-2";
}
/**
@@ -409,6 +678,57 @@ public class RealAIService {
return "sora-2";
}
/**
* 验证并规范化图片格式参考sora2实现
* 确保所有图片都是Base64格式带data URI前缀
* 支持PNG和JPEG格式JPEG用于压缩后的图片
* 参考Comfly.py 6236行data:image/png;base64,{base64_str} 或 data:image/jpeg;base64,{base64_str}
*/
private List<String> validateImageFormat(List<String> images) {
List<String> validatedImages = new java.util.ArrayList<>();
for (String img : images) {
if (img == null || img.isEmpty()) {
continue;
}
String validatedImg = img;
// 确保有data URI前缀参考Comfly.py 6236行
if (!img.startsWith("data:")) {
// 如果没有前缀,添加前缀
if (img.contains(",")) {
// 如果已经有逗号提取Base64数据并添加前缀
String pureBase64Data = img.substring(img.indexOf(",") + 1);
// 默认使用PNG格式如果无法确定格式
validatedImg = "data:image/png;base64," + pureBase64Data;
} else {
// 假设是纯Base64数据
validatedImg = "data:image/png;base64," + img;
}
} else {
// 如果已经有data URI前缀检查格式
// 支持 data:image/png;base64, 和 data:image/jpeg;base64,
if (!img.startsWith("data:image/png;base64,") &&
!img.startsWith("data:image/jpeg;base64,") &&
!img.startsWith("data:image/jpg;base64,")) {
// 如果不是标准格式,尝试修复
if (img.contains(",")) {
String pureBase64Data = img.substring(img.indexOf(",") + 1);
// 保持原有格式或默认使用PNG
if (img.contains("jpeg") || img.contains("jpg")) {
validatedImg = "data:image/jpeg;base64," + pureBase64Data;
} else {
validatedImg = "data:image/png;base64," + pureBase64Data;
}
}
}
}
validatedImages.add(validatedImg);
}
logger.debug("验证图片格式完成,原始数量: {}, 验证后数量: {}", images.size(), validatedImages.size());
return validatedImages;
}
/**
* 将图片文件转换为Base64
*/

View File

@@ -1,28 +1,46 @@
package com.example.demo.service;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.imageio.ImageIO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import com.example.demo.model.StoryboardVideoTask;
import com.example.demo.repository.StoryboardVideoTaskRepository;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* 分镜视频服务类
* 注意:不在类级别使用 @Transactional因为某些方法需要禁用事务如长时间运行的外部API调用
*/
@Service
@Transactional
public class StoryboardVideoService {
private static final Logger logger = LoggerFactory.getLogger(StoryboardVideoService.class);
@@ -35,143 +53,194 @@ public class StoryboardVideoService {
@Autowired
private ImageGridService imageGridService;
@Autowired
private TaskQueueService taskQueueService;
@Autowired
private ApplicationContext applicationContext;
@Autowired
private org.springframework.transaction.support.TransactionTemplate asyncTransactionTemplate;
@Autowired
private org.springframework.transaction.support.TransactionTemplate readOnlyTransactionTemplate;
// 默认生成6张分镜图
private static final int DEFAULT_STORYBOARD_IMAGES = 6;
private final ObjectMapper objectMapper = new ObjectMapper();
/**
* 创建分镜视频任务
* 注意:使用 REQUIRES_NEW 确保事务快速提交,避免长时间占用连接
* 事务提交后,异步方法在事务外执行
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public StoryboardVideoTask createTask(String username, String prompt, String aspectRatio, boolean hdMode, String imageUrl) {
try {
// 验证参数
if (username == null || username.trim().isEmpty()) {
throw new IllegalArgumentException("用户名不能为空");
}
if (prompt == null || prompt.trim().isEmpty()) {
throw new IllegalArgumentException("文本描述不能为空");
}
// 生成任务ID
String taskId = generateTaskId();
// 创建任务
StoryboardVideoTask task = new StoryboardVideoTask(username, prompt.trim(), aspectRatio, hdMode);
task.setTaskId(taskId);
task.setStatus(StoryboardVideoTask.TaskStatus.PENDING);
task.setProgress(0);
if (imageUrl != null && !imageUrl.isEmpty()) {
task.setImageUrl(imageUrl);
}
// 保存任务
task = taskRepository.save(task);
logger.info("分镜视频任务创建成功: {}, 用户: {}", taskId, username);
// 异步处理任务
processTaskAsync(taskId);
return task;
} catch (Exception e) {
logger.error("创建分镜视频任务失败", e);
throw new RuntimeException("创建任务失败: " + e.getMessage());
// 验证参数
if (username == null || username.trim().isEmpty()) {
throw new IllegalArgumentException("用户名不能为空");
}
if (prompt == null || prompt.trim().isEmpty()) {
throw new IllegalArgumentException("文本描述不能为空");
}
// 生成任务ID
String taskId = generateTaskId();
// 创建任务
StoryboardVideoTask task = new StoryboardVideoTask(username, prompt.trim(), aspectRatio, hdMode);
task.setTaskId(taskId);
task.setStatus(StoryboardVideoTask.TaskStatus.PENDING);
task.setProgress(0);
if (imageUrl != null && !imageUrl.isEmpty()) {
task.setImageUrl(imageUrl);
}
// 保存任务(快速完成,事务立即提交)
task = taskRepository.save(task);
logger.info("分镜视频任务创建成功: {}, 用户: {}", taskId, username);
// 注意:异步方法调用必须在事务提交后执行,避免占用连接
// 使用 TransactionSynchronizationManager 确保在事务提交后再调用异步方法
// 通过 ApplicationContext 获取代理对象,确保 @Async 生效
// 注意:获取代理对象的操作在事务内,但这是轻量级操作,不会长时间占用连接
final String finalTaskId = taskId;
final StoryboardVideoService self = applicationContext.getBean(StoryboardVideoService.class);
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
@Override
public void afterCommit() {
// 事务提交后,在事务外执行异步方法
// 使用代理对象调用,确保 @Async 生效
self.processTaskAsync(finalTaskId);
}
});
// 方法立即返回,事务快速提交
return task;
}
/**
* 使用真实API处理任务异步
* 使用Spring的@Async注解自动管理事务边界
* 注意此方法明确禁用事务因为长时间运行的外部API调用会占用数据库连接
* 只在需要数据库操作时使用单独的事务方法
*/
@Async
@Transactional
@Async("taskExecutor")
@Transactional(propagation = Propagation.NOT_SUPPORTED)
public void processTaskAsync(String taskId) {
try {
logger.info("开始使用真实API处理分镜视频任务: {}", taskId);
// 重新从数据库加载任务,获取最新状态
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
// 在异步方法中使用 TransactionTemplate 手动管理事务,确保事务正确关闭
StoryboardVideoTask taskInfo = loadTaskInfoWithTransactionTemplate(taskId);
String prompt = taskInfo.getPrompt();
String aspectRatio = taskInfo.getAspectRatio();
boolean hdMode = taskInfo.isHdMode();
// 更新任务状态为处理中
task.updateStatus(StoryboardVideoTask.TaskStatus.PROCESSING);
taskRepository.flush(); // 强制刷新到数据库
// 更新任务状态为处理中(使用 TransactionTemplate 确保事务正确关闭)
updateTaskStatusWithTransactionTemplate(taskId);
// 调用真实文生图API生成多张分镜图
// 参考Comfly项目如果API不支持一次生成多张图片则多次调用生成多张
logger.info("分镜视频任务已提交正在调用文生图API生成{}张分镜图...", DEFAULT_STORYBOARD_IMAGES);
logger.info("开始生成{}张分镜图...", DEFAULT_STORYBOARD_IMAGES);
// 收集所有图片URL
List<String> imageUrls = new ArrayList<>();
// 参考Comfly项目多次调用API生成多张图片因为Comfly API可能不支持一次生成多张
// 添加重试机制,提高成功率
int maxRetriesPerImage = 2; // 每张图片最多重试2次
long startTime = System.currentTimeMillis();
for (int i = 0; i < DEFAULT_STORYBOARD_IMAGES; i++) {
try {
logger.info("生成第{}张分镜图(共{}张)...", i + 1, DEFAULT_STORYBOARD_IMAGES);
// 每次调用生成1张图片使用banana模型
Map<String, Object> apiResponse = realAIService.submitTextToImageTask(
task.getPrompt(),
task.getAspectRatio(),
1, // 每次生成1张图片
task.isHdMode() // 使用任务的hdMode参数选择模型
);
// 检查API响应是否为空
if (apiResponse == null) {
logger.warn("第{}张图片API响应为null跳过", i + 1);
continue;
}
// 从API响应中提取图片URL
// 参考Comfly_nano_banana_edit节点响应格式为 {"data": [{"url": "...", "b64_json": "..."}]}
Object dataObj = apiResponse.get("data");
if (dataObj instanceof List) {
@SuppressWarnings("unchecked")
List<Map<String, Object>> data = (List<Map<String, Object>>) dataObj;
if (!data.isEmpty()) {
// 提取第一张图片的URL因为每次只生成1张
Map<String, Object> imageData = data.get(0);
if (imageData == null) {
logger.warn("第{}张图片data第一个元素为null跳过", i + 1);
continue;
}
String imageUrl = null;
Object urlObj = imageData.get("url");
Object b64JsonObj = imageData.get("b64_json");
if (urlObj != null) {
imageUrl = urlObj.toString();
} else if (b64JsonObj != null) {
// base64编码的图片
String base64Data = b64JsonObj.toString();
imageUrl = "data:image/png;base64," + base64Data;
}
if (imageUrl != null && !imageUrl.isEmpty()) {
imageUrls.add(imageUrl);
logger.info("成功获取第{}张分镜图", i + 1);
boolean imageGenerated = false;
int retryCount = 0;
// 重试机制如果单张图片生成失败重试最多2次
while (!imageGenerated && retryCount <= maxRetriesPerImage) {
try {
if (retryCount > 0) {
logger.info("重试生成第{}张分镜图(第{}次重试)...", i + 1, retryCount);
Thread.sleep(1000 * retryCount); // 重试时延迟递增
}
// 每次调用生成1张图片使用banana模型
Map<String, Object> apiResponse = realAIService.submitTextToImageTask(
prompt,
aspectRatio,
1, // 每次生成1张图片
hdMode // 使用任务的hdMode参数选择模型
);
// 检查API响应是否为空
if (apiResponse == null) {
retryCount++;
continue;
}
// 从API响应中提取图片URL
Object dataObj = apiResponse.get("data");
if (dataObj instanceof List) {
@SuppressWarnings("unchecked")
List<Map<String, Object>> data = (List<Map<String, Object>>) dataObj;
if (!data.isEmpty()) {
// 提取第一张图片的URL因为每次只生成1张
Map<String, Object> imageData = data.get(0);
if (imageData == null) {
retryCount++;
continue;
}
// 提取图片数据优先使用Base64其次使用URL
String imageBase64 = extractImageAsBase64(imageData);
if (imageBase64 != null && !imageBase64.isEmpty()) {
imageUrls.add(imageBase64);
imageGenerated = true;
long elapsed = System.currentTimeMillis() - startTime;
int progress = (int) ((i + 1) * 100.0 / DEFAULT_STORYBOARD_IMAGES);
logger.info("✓ 成功生成第{}/{}张分镜图(进度: {}%, 耗时: {}ms, Base64长度: {}",
i + 1, DEFAULT_STORYBOARD_IMAGES, progress, elapsed, imageBase64.length());
} else {
logger.warn("未能提取第{}张分镜图的数据", i + 1);
retryCount++;
}
} else {
logger.warn("第{}张图片URL为空跳过", i + 1);
retryCount++;
}
} else {
logger.warn("第{}张图片API响应data为空列表跳过", i + 1);
retryCount++;
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("生成分镜图被中断: {}", taskId, e);
throw new RuntimeException("生成分镜图被中断", e);
} catch (Exception e) {
retryCount++;
if (retryCount > maxRetriesPerImage) {
logger.warn("生成第{}张分镜图失败,已重试{}次: {}, 继续生成其他图片",
i + 1, maxRetriesPerImage, e.getMessage());
// 记录详细错误信息仅在debug级别
if (logger.isDebugEnabled()) {
logger.debug("生成第{}张分镜图失败详情", i + 1, e);
}
} else {
logger.debug("生成第{}张分镜图失败,将重试: {}", i + 1, e.getMessage());
}
} else {
logger.warn("第{}张图片API响应data格式不正确不是列表跳过", i + 1);
}
// 在多次调用之间添加短暂延迟避免API限流
if (i < DEFAULT_STORYBOARD_IMAGES - 1) {
}
// 在多次调用之间添加短暂延迟避免API限流
if (i < DEFAULT_STORYBOARD_IMAGES - 1) {
try {
Thread.sleep(500); // 延迟500ms
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("生成分镜图被中断", e);
}
} catch (Exception e) {
logger.error("生成第{}张分镜图失败: {}", i + 1, e.getMessage());
// 继续生成其他图片,不因单张失败而终止整个流程
}
}
@@ -179,45 +248,214 @@ public class StoryboardVideoService {
throw new RuntimeException("未能从API响应中提取任何图片URL");
}
// 必须生成6张图片才能继续否则抛出异常
if (imageUrls.size() < DEFAULT_STORYBOARD_IMAGES) {
logger.warn("只生成了{}张图片,少于预期的{}张", imageUrls.size(), DEFAULT_STORYBOARD_IMAGES);
String errorMsg = String.format("只生成了%d张图片,少于预期的%d张无法拼接分镜图",
imageUrls.size(), DEFAULT_STORYBOARD_IMAGES);
logger.error(errorMsg);
throw new RuntimeException(errorMsg);
}
logger.info("成功获取{}张图片,开始拼接成分镜图网格...", imageUrls.size());
// 确保正好是6张图片如果多于6张只取前6张
if (imageUrls.size() > DEFAULT_STORYBOARD_IMAGES) {
logger.warn("生成了{}张图片,多于预期的{}张,只取前{}张进行拼接",
imageUrls.size(), DEFAULT_STORYBOARD_IMAGES, DEFAULT_STORYBOARD_IMAGES);
imageUrls = imageUrls.subList(0, DEFAULT_STORYBOARD_IMAGES);
}
// 拼接多张图片成网格
String mergedImageUrl = imageGridService.mergeImagesToGrid(imageUrls, 0); // 0表示自动计算列数
long totalTime = System.currentTimeMillis() - startTime;
logger.info("成功获取{}张图片(总耗时: {}ms开始验证并规范化图片格式...", imageUrls.size(), totalTime);
// 验证所有图片都是Base64格式带data URI前缀
// 参考sora2实现确保所有图片格式一致
long validateStartTime = System.currentTimeMillis();
List<String> validatedImages = validateAndNormalizeImages(imageUrls);
long validateTime = System.currentTimeMillis() - validateStartTime;
logger.debug("图片格式验证完成,耗时: {}ms", validateTime);
if (validatedImages.size() < DEFAULT_STORYBOARD_IMAGES) {
String errorMsg = String.format("验证后只有%d张图片少于预期的%d张无法拼接分镜图",
validatedImages.size(), DEFAULT_STORYBOARD_IMAGES);
logger.error(errorMsg);
throw new RuntimeException(errorMsg);
}
logger.info("开始拼接{}张图片成分镜图网格...", validatedImages.size());
// 拼接多张图片成网格此时确保有6张图片
// 使用验证后的图片列表都是Base64格式
long mergeStartTime = System.currentTimeMillis();
String mergedImageUrl = imageGridService.mergeImagesToGrid(validatedImages, 0); // 0表示自动计算列数
long mergeTime = System.currentTimeMillis() - mergeStartTime;
logger.info("图片网格拼接完成,耗时: {}ms", mergeTime);
// 检查拼接后的图片URL是否有效
if (mergedImageUrl == null || mergedImageUrl.isEmpty()) {
throw new RuntimeException("图片拼接失败: 返回的图片URL为空");
}
// 重新加载任务因为之前的flush可能使实体detached
task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
// 保存单独的分镜图片Base64数组参考sora2实现
// validatedImages 已在上面定义并验证
// 将图片列表转换为JSON数组每张图片都是Base64格式带data URI前缀
String storyboardImagesJson = null;
try {
storyboardImagesJson = objectMapper.writeValueAsString(validatedImages);
logger.debug("分镜图片JSON长度: {} 字符", storyboardImagesJson.length());
if (logger.isTraceEnabled()) {
logger.trace("分镜图片JSON前500字符: {}",
storyboardImagesJson.length() > 500 ? storyboardImagesJson.substring(0, 500) + "..." : storyboardImagesJson);
}
} catch (Exception e) {
logger.error("转换分镜图片为JSON失败: {}", taskId, e);
// 如果转换失败,继续使用网格图
}
// 设置拼接后的结果图片URL
task.setResultUrl(mergedImageUrl);
task.setRealTaskId(taskId + "_image");
task.updateStatus(StoryboardVideoTask.TaskStatus.COMPLETED);
task.updateProgress(100);
// 只有在6张图片都生成并拼接完成后才保存结果图片URL
// 使用 TransactionTemplate 确保事务正确关闭
long saveStartTime = System.currentTimeMillis();
saveStoryboardImageResultWithTransactionTemplate(taskId, mergedImageUrl, storyboardImagesJson, validatedImages.size());
long saveTime = System.currentTimeMillis() - saveStartTime;
taskRepository.save(task);
long totalElapsed = System.currentTimeMillis() - startTime;
logger.info("✓ 分镜图生成完成: taskId={}, 共{}张图片,已拼接完成,总耗时: {}ms (生成: {}ms, 验证: {}ms, 拼接: {}ms, 保存: {}ms)",
taskId, validatedImages.size(), totalElapsed, totalTime, validateTime, mergeTime, saveTime);
logger.info("分镜图生成并拼接完成任务ID: {}, 共生成{}张图片", taskId, imageUrls.size());
// 不再自动生成视频,等待用户点击"开始生成"按钮
} catch (Exception e) {
logger.error("处理分镜视频任务失败: {}", taskId, e);
// 更新任务失败状态(使用 TransactionTemplate 确保事务正确关闭)
updateTaskStatusToFailedWithTransactionTemplate(taskId, e.getMessage());
}
}
/**
* 在异步方法中加载任务信息使用配置好的只读事务模板超时2秒确保快速完成
*/
private StoryboardVideoTask loadTaskInfoWithTransactionTemplate(String taskId) {
// 使用配置好的只读事务模板超时2秒确保快速完成
return readOnlyTransactionTemplate.execute(status -> {
return taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
});
}
/**
* 在异步方法中更新任务状态为处理中使用配置好的异步事务模板超时3秒确保快速完成
*/
private void updateTaskStatusWithTransactionTemplate(String taskId) {
// 使用配置好的异步事务模板超时3秒确保快速完成
asyncTransactionTemplate.executeWithoutResult(status -> {
try {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.updateStatus(StoryboardVideoTask.TaskStatus.FAILED);
task.setErrorMessage(e.getMessage());
task.updateStatus(StoryboardVideoTask.TaskStatus.PROCESSING);
taskRepository.save(task);
} catch (Exception ex) {
logger.error("更新任务失败状态失败: {}", taskId, ex);
logger.info("任务状态已更新为处理中: {}", taskId);
} catch (Exception e) {
logger.error("更新任务状态失败: {}", taskId, e);
status.setRollbackOnly();
throw e;
}
});
}
/**
* 在异步方法中保存分镜图结果使用配置好的异步事务模板超时3秒确保快速完成
* 参考sora2实现保存网格图和单独的分镜图片
*/
private void saveStoryboardImageResultWithTransactionTemplate(String taskId, String mergedImageUrl, String storyboardImagesJson, int validatedImageCount) {
asyncTransactionTemplate.executeWithoutResult(status -> {
try {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.setResultUrl(mergedImageUrl); // 网格图(用于前端显示)
if (storyboardImagesJson != null && !storyboardImagesJson.isEmpty()) {
task.setStoryboardImages(storyboardImagesJson); // 单独的分镜图片(用于视频生成)
}
task.updateProgress(50); // 分镜图生成完成进度50%
taskRepository.save(task);
logger.debug("分镜图结果已保存: taskId={}, 图片数量={}", taskId, validatedImageCount);
} catch (Exception e) {
logger.error("保存分镜图结果失败: {}", taskId, e);
status.setRollbackOnly();
throw e;
}
});
}
/**
* 在异步方法中更新任务状态为失败使用配置好的异步事务模板超时3秒确保快速完成
*/
private void updateTaskStatusToFailedWithTransactionTemplate(String taskId, String errorMessage) {
try {
asyncTransactionTemplate.executeWithoutResult(status -> {
try {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.updateStatus(StoryboardVideoTask.TaskStatus.FAILED);
task.setErrorMessage(errorMessage);
taskRepository.save(task);
} catch (Exception e) {
logger.error("更新任务失败状态失败: {}", taskId, e);
status.setRollbackOnly();
}
});
} catch (Exception e) {
logger.error("执行更新失败状态事务失败: {}", taskId, e);
}
}
/**
* 获取任务信息用于处理(只读事务,快速完成)
*/
@Transactional(readOnly = true)
public StoryboardVideoTask getTaskInfoForProcessing(String taskId) {
return taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
}
/**
* 更新任务状态为处理中(单独的事务方法)
*/
@Transactional
public void updateTaskStatusToProcessing(String taskId) {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.updateStatus(StoryboardVideoTask.TaskStatus.PROCESSING);
taskRepository.save(task);
}
/**
* 保存分镜图结果(单独的事务方法)
*/
@Transactional
public void saveStoryboardImageResult(String taskId, String mergedImageUrl, String storyboardImagesJson, int validatedImageCount) {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.setResultUrl(mergedImageUrl); // 网格图(用于前端显示)
if (storyboardImagesJson != null && !storyboardImagesJson.isEmpty()) {
task.setStoryboardImages(storyboardImagesJson); // 单独的分镜图片(用于视频生成)
}
task.updateProgress(50); // 分镜图生成完成进度50%
// 状态保持 PROCESSING等待用户点击"开始生成"按钮后再生成视频
taskRepository.save(task);
logger.debug("分镜图结果已保存: taskId={}, 图片数量={}", taskId, validatedImageCount);
}
/**
* 更新任务状态为失败(单独的事务方法)
*/
@Transactional
public void updateTaskStatusToFailed(String taskId, String errorMessage) {
try {
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务未找到: " + taskId));
task.updateStatus(StoryboardVideoTask.TaskStatus.FAILED);
task.setErrorMessage(errorMessage);
taskRepository.save(task);
} catch (Exception ex) {
logger.error("更新任务失败状态失败: {}", taskId, ex);
}
}
@@ -240,10 +478,400 @@ public class StoryboardVideoService {
return taskPage.getContent();
}
/**
* 开始生成视频(从分镜图生成视频)
* 用户点击"开始生成"按钮后调用
*/
@Transactional
public void startVideoGeneration(String taskId) {
try {
logger.debug("收到开始生成视频请求任务ID: {}", taskId);
// 重新加载任务
StoryboardVideoTask task = taskRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("任务不存在: " + taskId));
// 检查分镜图是否已生成
if (task.getResultUrl() == null || task.getResultUrl().isEmpty()) {
throw new RuntimeException("分镜图尚未生成,无法生成视频");
}
// 检查任务状态
if (task.getStatus() != StoryboardVideoTask.TaskStatus.PROCESSING) {
throw new RuntimeException("任务状态不正确,无法生成视频。当前状态: " + task.getStatus());
}
// 检查是否已经添加过视频生成任务(避免重复添加)
// 这里可以通过检查任务队列来判断,但为了简单,我们直接添加
// 如果已经存在TaskQueueService 会处理重复的情况
// 将视频生成任务添加到任务队列,由队列异步处理
logger.debug("开始将视频生成任务添加到队列: {}", taskId);
try {
taskQueueService.addStoryboardVideoTask(task.getUsername(), taskId);
// 任务状态保持 PROCESSING等待视频生成完成
} catch (Exception e) {
logger.error("添加分镜视频任务到队列失败: {}", taskId, e);
throw new RuntimeException("添加视频生成任务失败: " + e.getMessage());
}
} catch (Exception e) {
logger.error("开始生成视频失败: {}", taskId, e);
throw new RuntimeException("开始生成视频失败: " + e.getMessage());
}
}
/**
* 从图片数据中提取Base64格式的图片参考sora2实现
* 优先使用b64_json如果不存在则下载URL并转换为Base64
* 确保返回的格式为data:image/png;base64,{base64_str}
*/
private String extractImageAsBase64(Map<String, Object> imageData) {
if (imageData == null) {
return null;
}
// 优先使用b64_json参考banana实现
Object b64JsonObj = imageData.get("b64_json");
if (b64JsonObj != null) {
String base64Data = b64JsonObj.toString();
// 确保有data URI前缀参考Comfly.py 6236行
if (!base64Data.startsWith("data:")) {
return "data:image/png;base64," + base64Data;
} else {
return base64Data;
}
}
// 如果没有b64_json尝试下载URL并转换为Base64
Object urlObj = imageData.get("url");
if (urlObj != null) {
String imageUrl = urlObj.toString();
try {
// 下载图片并转换为Base64参考sora2实现
String base64Image = downloadImageAndConvertToBase64(imageUrl);
if (base64Image != null && !base64Image.isEmpty()) {
return base64Image;
}
} catch (Exception e) {
logger.warn("下载并转换图片失败: {}, 错误: {}", imageUrl, e.getMessage());
}
}
return null;
}
/**
* 压缩图片以减小体积
* @param originalImage 原始图片
* @param maxSize 最大尺寸(宽度或高度)
* @param quality 压缩质量0.0-1.0,未使用,保留用于未来扩展)
* @return 压缩后的图片
*/
private BufferedImage compressImage(BufferedImage originalImage, int maxSize, float quality) {
int originalWidth = originalImage.getWidth();
int originalHeight = originalImage.getHeight();
// 如果图片尺寸小于等于最大尺寸,直接返回
if (originalWidth <= maxSize && originalHeight <= maxSize) {
return originalImage;
}
// 计算缩放比例
double scale = Math.min((double) maxSize / originalWidth, (double) maxSize / originalHeight);
int newWidth = (int) (originalWidth * scale);
int newHeight = (int) (originalHeight * scale);
logger.debug("压缩图片: {}x{} -> {}x{} (缩放比例: {})",
originalWidth, originalHeight, newWidth, newHeight, scale);
// 创建缩放后的图片
BufferedImage compressedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = compressedImage.createGraphics();
// 设置高质量缩放
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g.setRenderingHint(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
g.drawImage(originalImage, 0, 0, newWidth, newHeight, null);
g.dispose();
return compressedImage;
}
/**
* 压缩Base64图片如果图片过大
* @param base64Image Base64编码的图片带data URI前缀
* @param maxSize 最大尺寸(宽度或高度)
* @param quality JPEG压缩质量0.0-1.0
* @return 压缩后的Base64图片
*/
private String compressBase64Image(String base64Image, int maxSize, float quality) {
if (base64Image == null || !base64Image.startsWith("data:image")) {
return base64Image;
}
try {
// 提取Base64数据
String base64Data = base64Image.substring(base64Image.indexOf(",") + 1);
byte[] imageBytes = Base64.getDecoder().decode(base64Data);
BufferedImage image = ImageIO.read(new java.io.ByteArrayInputStream(imageBytes));
if (image == null) {
return base64Image;
}
// 检查是否需要压缩
int originalWidth = image.getWidth();
int originalHeight = image.getHeight();
int originalSize = imageBytes.length;
if (originalWidth <= maxSize && originalHeight <= maxSize && originalSize < 500 * 1024) {
// 图片已经足够小,不需要压缩
logger.debug("图片无需压缩: {}x{}, 大小: {} KB", originalWidth, originalHeight, originalSize / 1024);
return base64Image;
}
// 压缩图片
BufferedImage compressedImage = compressImage(image, maxSize, quality);
// 转换为JPEG格式的Base64
ByteArrayOutputStream baos = new ByteArrayOutputStream();
javax.imageio.ImageWriter writer = javax.imageio.ImageIO.getImageWritersByFormatName("jpg").next();
javax.imageio.ImageWriteParam param = writer.getDefaultWriteParam();
if (param.canWriteCompressed()) {
param.setCompressionMode(javax.imageio.ImageWriteParam.MODE_EXPLICIT);
param.setCompressionQuality(quality);
}
javax.imageio.IIOImage iioImage = new javax.imageio.IIOImage(compressedImage, null, null);
writer.setOutput(javax.imageio.ImageIO.createImageOutputStream(baos));
writer.write(null, iioImage, param);
writer.dispose();
byte[] compressedBytes = baos.toByteArray();
String compressedBase64 = Base64.getEncoder().encodeToString(compressedBytes);
double compressionRatio = (1.0 - (double) compressedBytes.length / originalSize) * 100;
logger.info("图片压缩完成: {}x{} -> {}x{}, 大小: {} KB -> {} KB (压缩率: {}%)",
originalWidth, originalHeight,
compressedImage.getWidth(), compressedImage.getHeight(),
originalSize / 1024, compressedBytes.length / 1024,
String.format("%.1f", compressionRatio));
return "data:image/jpeg;base64," + compressedBase64;
} catch (Exception e) {
logger.warn("压缩图片失败,使用原始图片: {}", e.getMessage());
return base64Image;
}
}
/**
* 验证并规范化图片格式参考sora2实现
* 确保所有图片都是Base64格式带data URI前缀
* 同时压缩过大的图片以减小请求体积
*/
private List<String> validateAndNormalizeImages(List<String> imageUrls) {
List<String> validatedImages = new ArrayList<>();
for (String img : imageUrls) {
if (img == null || img.isEmpty()) {
continue;
}
// 确保有data URI前缀参考Comfly.py 6236行
String normalizedImg = img;
if (!img.startsWith("data:")) {
// 如果没有前缀,尝试添加
if (img.contains(",")) {
// 如果已经有逗号提取Base64数据并添加前缀
String pureBase64Data = img.substring(img.indexOf(",") + 1);
normalizedImg = "data:image/png;base64," + pureBase64Data;
} else {
// 假设是纯Base64数据
normalizedImg = "data:image/png;base64," + img;
}
}
// 压缩图片以减小体积最大1024pxJPEG质量85%
// 这样可以显著减小请求体大小从17MB降低到几MB
normalizedImg = compressBase64Image(normalizedImg, 1024, 0.85f);
validatedImages.add(normalizedImg);
}
logger.debug("验证并规范化图片完成,原始数量: {}, 验证后数量: {}", imageUrls.size(), validatedImages.size());
return validatedImages;
}
/**
* 下载图片并转换为Base64格式参考sora2实现
* 返回格式data:image/png;base64,{base64_str}
* 添加超时控制和重试机制
*/
private String downloadImageAndConvertToBase64(String imageUrl) {
if (imageUrl == null || imageUrl.isEmpty()) {
return null;
}
// 如果已经是Base64格式直接返回
if (imageUrl.startsWith("data:image")) {
return imageUrl;
}
// 重试机制最多重试2次
int maxRetries = 2;
int retryCount = 0;
long connectTimeout = 10000; // 10秒连接超时
long readTimeout = 30000; // 30秒读取超时
while (retryCount <= maxRetries) {
try {
if (retryCount > 0) {
logger.debug("重试下载图片(第{}次): {}", retryCount, imageUrl);
Thread.sleep(1000 * retryCount); // 重试延迟递增
}
// 下载图片使用URI避免deprecated警告
URI uri = new URI(imageUrl);
URL url = uri.toURL();
// 使用HttpURLConnection以便设置超时
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setConnectTimeout((int) connectTimeout);
connection.setReadTimeout((int) readTimeout);
connection.setRequestMethod("GET");
connection.setRequestProperty("User-Agent", "Mozilla/5.0");
connection.setRequestProperty("Accept", "image/*");
BufferedImage image;
try (InputStream in = connection.getInputStream()) {
image = ImageIO.read(in);
} finally {
connection.disconnect();
}
if (image == null) {
logger.warn("无法读取图片: {}", imageUrl);
retryCount++;
continue;
}
// 压缩图片以减小体积(限制最大尺寸和压缩质量)
BufferedImage compressedImage = compressImage(image, 1024, 0.85f); // 最大1024px质量85%
// 转换为JPEG格式的Base64JPEG压缩率更高体积更小
ByteArrayOutputStream baos = new ByteArrayOutputStream();
javax.imageio.ImageWriter writer = javax.imageio.ImageIO.getImageWritersByFormatName("jpg").next();
javax.imageio.ImageWriteParam param = writer.getDefaultWriteParam();
if (param.canWriteCompressed()) {
param.setCompressionMode(javax.imageio.ImageWriteParam.MODE_EXPLICIT);
param.setCompressionQuality(0.85f); // JPEG质量85%
}
javax.imageio.IIOImage iioImage = new javax.imageio.IIOImage(compressedImage, null, null);
writer.setOutput(javax.imageio.ImageIO.createImageOutputStream(baos));
writer.write(null, iioImage, param);
writer.dispose();
byte[] imageBytes = baos.toByteArray();
String base64 = Base64.getEncoder().encodeToString(imageBytes);
// 返回带data URI前缀的Base64字符串使用JPEG格式以减小体积
logger.debug("成功下载并转换图片: {} (原始: {}x{}, 压缩后: {} KB)",
imageUrl, image.getWidth(), image.getHeight(), imageBytes.length / 1024);
return "data:image/jpeg;base64," + base64;
} catch (java.net.SocketTimeoutException | java.net.ConnectException e) {
retryCount++;
if (retryCount > maxRetries) {
logger.error("下载图片超时或连接失败(已重试{}次): {}", maxRetries, imageUrl, e);
return null;
}
logger.debug("下载图片超时,将重试: {}", imageUrl);
} catch (IOException e) {
retryCount++;
if (retryCount > maxRetries) {
logger.error("下载图片失败(已重试{}次): {}", maxRetries, imageUrl, e);
return null;
}
logger.debug("下载图片失败,将重试: {}", imageUrl);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("下载图片被中断: {}", imageUrl, e);
return null;
} catch (Exception e) {
logger.error("转换图片为Base64失败: {}", imageUrl, e);
return null;
}
}
return null;
}
/**
* 生成任务ID
*/
private String generateTaskId() {
return "sb_" + UUID.randomUUID().toString().replace("-", "").substring(0, 16);
}
/**
* 检查并处理超时的分镜图生成任务
* 如果任务状态为PROCESSINGrealTaskId为空说明还在生成分镜图阶段且创建时间超过10分钟则标记为超时
* 注意如果任务已经有resultUrl分镜图已生成即使超时也不标记为失败因为分镜图已经成功生成
*/
@Transactional
public int checkAndHandleTimeoutTasks() {
try {
// 计算超时时间点10分钟前
LocalDateTime timeoutTime = LocalDateTime.now().minusMinutes(10);
// 查找超时的任务状态为PROCESSINGrealTaskId为空创建时间超过10分钟
List<StoryboardVideoTask> timeoutTasks = taskRepository.findTimeoutTasks(
StoryboardVideoTask.TaskStatus.PROCESSING,
timeoutTime
);
if (timeoutTasks.isEmpty()) {
return 0;
}
logger.warn("发现 {} 个可能超时的分镜图生成任务,开始检查", timeoutTasks.size());
int handledCount = 0;
int skippedCount = 0;
for (StoryboardVideoTask task : timeoutTasks) {
try {
// 检查任务是否已经有resultUrl分镜图已生成
// 如果有resultUrl说明分镜图已经成功生成不应该被标记为超时失败
if (task.getResultUrl() != null && !task.getResultUrl().isEmpty()) {
logger.debug("任务 {} 已有resultUrl分镜图已生成跳过超时标记", task.getTaskId());
skippedCount++;
continue;
}
// 更新任务状态为失败
task.updateStatus(StoryboardVideoTask.TaskStatus.FAILED);
task.setErrorMessage("分镜图生成超时任务创建后超过10分钟仍未完成");
taskRepository.save(task);
logger.warn("分镜图生成任务超时,已标记为失败: taskId={}", task.getTaskId());
handledCount++;
} catch (Exception e) {
logger.error("处理超时分镜图生成任务失败: taskId={}", task.getTaskId(), e);
}
}
if (handledCount > 0 || skippedCount > 0) {
logger.info("处理超时分镜图生成任务完成,失败: {}/{},跳过(已生成): {}",
handledCount, timeoutTasks.size(), skippedCount);
}
return handledCount;
} catch (Exception e) {
logger.error("检查超时分镜图生成任务失败", e);
return 0;
}
}
}

View File

@@ -9,6 +9,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import com.example.demo.model.TaskStatus;
@@ -79,75 +80,72 @@ public class TaskStatusPollingService {
/**
* 轮询单个任务状态
* 注意此方法明确禁用事务因为长时间运行的外部API调用会占用数据库连接
*/
@Transactional
@Transactional(propagation = Propagation.NOT_SUPPORTED)
public void pollTaskStatus(TaskStatus task) {
logger.info("轮询任务状态: taskId={}, externalTaskId={}", task.getTaskId(), task.getExternalTaskId());
try {
// 调用外部API查询状态
HttpResponse<String> response = Unirest.get(apiBaseUrl + "/v1/videos")
// 调用外部API查询状态(长时间运行,不在事务中)
HttpResponse<String> response = Unirest.post(apiBaseUrl + "/v1/videos")
.header("Authorization", "Bearer " + apiKey)
.queryString("task_id", task.getExternalTaskId())
.field("task_id", task.getExternalTaskId())
.asString();
if (response.getStatus() == 200) {
JsonNode responseJson = objectMapper.readTree(response.getBody());
updateTaskStatus(task, responseJson);
// 更新任务状态(使用单独的事务方法)
updateTaskStatusWithTransaction(task, responseJson);
} else {
logger.warn("查询任务状态失败: taskId={}, status={}, response={}",
task.getTaskId(), response.getStatus(), response.getBody());
task.incrementPollCount();
taskStatusRepository.save(task);
// 更新轮询次数(使用单独的事务方法)
incrementPollCountWithTransaction(task);
}
} catch (Exception e) {
logger.error("轮询任务状态异常: taskId={}, error={}", task.getTaskId(), e.getMessage(), e);
task.incrementPollCount();
taskStatusRepository.save(task);
// 更新轮询次数(使用单独的事务方法)
incrementPollCountWithTransaction(task);
}
}
/**
* 更新任务状态(单独的事务方法)
*/
@Transactional
public void updateTaskStatusWithTransaction(TaskStatus task, JsonNode responseJson) {
updateTaskStatus(task, responseJson);
}
/**
* 增加轮询次数(单独的事务方法)
*/
@Transactional
public void incrementPollCountWithTransaction(TaskStatus task) {
task.incrementPollCount();
taskStatusRepository.save(task);
}
/**
* 更新任务状态
*/
private void updateTaskStatus(TaskStatus task, JsonNode responseJson) {
try {
// 检查base_resp状态
JsonNode baseResp = responseJson.path("base_resp");
if (!baseResp.isMissingNode() && baseResp.path("status_code").asInt() != 0) {
String errorMsg = baseResp.path("status_msg").asText("Unknown error");
task.markAsFailed(errorMsg);
logger.warn("API返回错误: taskId={}, error={}", task.getTaskId(), errorMsg);
taskStatusRepository.save(task);
return;
}
String status = responseJson.path("status").asText();
int progress = responseJson.path("progress").asInt(0);
String resultUrl = null;
String resultUrl = responseJson.path("result_url").asText();
String errorMessage = responseJson.path("error_message").asText();
task.incrementPollCount();
task.setProgress(progress);
switch (status.toLowerCase()) {
case "completed":
case "success":
// 获取file_id并获取视频URL
String fileId = responseJson.path("file_id").asText();
if (!fileId.isEmpty()) {
resultUrl = getVideoUrlFromFileId(fileId);
if (resultUrl != null) {
task.markAsCompleted(resultUrl);
logger.info("任务完成: taskId={}, resultUrl={}", task.getTaskId(), resultUrl);
} else {
task.markAsFailed("无法获取视频URL");
logger.warn("任务完成但无法获取视频URL: taskId={}, fileId={}", task.getTaskId(), fileId);
}
} else {
task.markAsFailed("任务完成但未返回文件ID");
logger.warn("任务完成但未返回文件ID: taskId={}", task.getTaskId());
}
task.markAsCompleted(resultUrl);
logger.info("任务完成: taskId={}, resultUrl={}", task.getTaskId(), resultUrl);
break;
case "failed":
@@ -158,7 +156,6 @@ public class TaskStatusPollingService {
case "processing":
case "in_progress":
case "pending":
task.setStatus(TaskStatus.Status.PROCESSING);
logger.info("任务处理中: taskId={}, progress={}%", task.getTaskId(), progress);
break;
@@ -175,38 +172,6 @@ public class TaskStatusPollingService {
}
}
/**
* 根据file_id获取视频URL
*/
private String getVideoUrlFromFileId(String fileId) {
try {
HttpResponse<String> response = Unirest.get(apiBaseUrl + "/minimax/v1/files/retrieve")
.header("Authorization", "Bearer " + apiKey)
.queryString("file_id", fileId)
.asString();
if (response.getStatus() == 200) {
JsonNode responseJson = objectMapper.readTree(response.getBody());
JsonNode fileNode = responseJson.path("file");
if (!fileNode.isMissingNode()) {
String downloadUrl = fileNode.path("download_url").asText();
if (!downloadUrl.isEmpty()) {
logger.info("成功获取视频URL: fileId={}, url={}", fileId, downloadUrl);
return downloadUrl;
}
}
}
logger.warn("获取视频URL失败: fileId={}, status={}, response={}",
fileId, response.getStatus(), response.getBody());
return null;
} catch (Exception e) {
logger.error("获取视频URL时发生错误: fileId={}, error={}", fileId, e.getMessage(), e);
return null;
}
}
/**
* 处理超时任务
*/

View File

@@ -22,9 +22,9 @@ import com.example.demo.repository.TextToVideoTaskRepository;
/**
* 文生视频服务类
* 注意:不在类级别使用 @Transactional因为某些方法需要禁用事务如长时间运行的外部API调用
*/
@Service
@Transactional
public class TextToVideoService {
private static final Logger logger = LoggerFactory.getLogger(TextToVideoService.class);
@@ -44,6 +44,7 @@ public class TextToVideoService {
/**
* 创建文生视频任务
*/
@Transactional
public TextToVideoTask createTask(String username, String prompt, String aspectRatio, int duration, boolean hdMode) {
try {
// 验证参数
@@ -87,21 +88,39 @@ public class TextToVideoService {
/**
* 使用真实API处理任务
*/
@Async
@Async("taskExecutor")
public CompletableFuture<Void> processTaskWithRealAPI(TextToVideoTask task) {
try {
logger.info("开始使用真实API处理文生视频任务: {}", task.getTaskId());
// 重新从数据库加载任务,确保获取最新状态
TextToVideoTask currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
// 检查任务是否已经有 realTaskId如果有说明已经提交过了不应该再次处理
if (currentTask.getRealTaskId() != null && !currentTask.getRealTaskId().isEmpty()) {
logger.warn("文生视频任务 {} 已经有 realTaskId{}),说明已经提交过了,跳过处理",
task.getTaskId(), currentTask.getRealTaskId());
return CompletableFuture.completedFuture(null);
}
// 检查任务状态如果已经不是PENDING说明已经被其他线程处理了
if (currentTask.getStatus() != TextToVideoTask.TaskStatus.PENDING) {
logger.warn("文生视频任务 {} 状态已不是PENDING当前状态: {}),跳过处理,可能已被其他线程处理",
task.getTaskId(), currentTask.getStatus());
return CompletableFuture.completedFuture(null);
}
// 更新任务状态为处理中
task.updateStatus(TextToVideoTask.TaskStatus.PROCESSING);
taskRepository.save(task);
currentTask.updateStatus(TextToVideoTask.TaskStatus.PROCESSING);
taskRepository.save(currentTask);
// 调用真实API提交任务
Map<String, Object> apiResponse = realAIService.submitTextToVideoTask(
task.getPrompt(),
task.getAspectRatio(),
String.valueOf(task.getDuration()),
task.isHdMode()
currentTask.getPrompt(),
currentTask.getAspectRatio(),
String.valueOf(currentTask.getDuration()),
currentTask.isHdMode()
);
// 从API响应中提取真实任务ID
@@ -138,20 +157,25 @@ public class TextToVideoService {
// 如果找到了真实任务ID保存到数据库
if (realTaskId != null) {
task.setRealTaskId(realTaskId);
taskRepository.save(task);
// 重新加载任务以确保获取最新状态
currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
currentTask.setRealTaskId(realTaskId);
taskRepository.save(currentTask);
logger.info("真实任务ID已保存: {} -> {}", task.getTaskId(), realTaskId);
} else {
// 如果没有找到任务ID说明任务提交失败
logger.error("任务提交失败未从API响应中获取到任务ID");
task.updateStatus(TextToVideoTask.TaskStatus.FAILED);
task.setErrorMessage("任务提交失败API未返回有效的任务ID");
taskRepository.save(task);
currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElseThrow(() -> new RuntimeException("任务不存在: " + task.getTaskId()));
currentTask.updateStatus(TextToVideoTask.TaskStatus.FAILED);
currentTask.setErrorMessage("任务提交失败API未返回有效的任务ID");
taskRepository.save(currentTask);
return CompletableFuture.completedFuture(null); // 直接返回,不进行轮询
}
// 开始轮询真实任务状态
pollRealTaskStatus(task);
pollRealTaskStatus(currentTask);
} catch (Exception e) {
logger.error("使用真实API处理文生视频任务失败: {}", task.getTaskId(), e);
@@ -161,10 +185,15 @@ public class TextToVideoService {
}
try {
// 更新状态为失败
task.updateStatus(TextToVideoTask.TaskStatus.FAILED);
task.setErrorMessage(e.getMessage());
taskRepository.save(task);
// 重新加载任务以确保获取最新状态
TextToVideoTask currentTask = taskRepository.findByTaskId(task.getTaskId())
.orElse(null);
if (currentTask != null) {
// 更新状态为失败
currentTask.updateStatus(TextToVideoTask.TaskStatus.FAILED);
currentTask.setErrorMessage(e.getMessage());
taskRepository.save(currentTask);
}
} catch (Exception saveException) {
logger.error("保存失败状态时出错: {}", task.getTaskId(), saveException);
}
@@ -361,31 +390,6 @@ public class TextToVideoService {
return taskRepository.findByTaskIdAndUsername(taskId, username).orElse(null);
}
/**
* 取消任务
*/
@Transactional
public boolean cancelTask(String taskId, String username) {
// 使用悲观锁避免并发问题
TextToVideoTask task = taskRepository.findByTaskId(taskId).orElse(null);
if (task == null || task.getUsername() == null || !task.getUsername().equals(username)) {
return false;
}
// 检查任务状态只有PENDING和PROCESSING状态的任务才能取消
if (task.getStatus() == TextToVideoTask.TaskStatus.PENDING ||
task.getStatus() == TextToVideoTask.TaskStatus.PROCESSING) {
task.updateStatus(TextToVideoTask.TaskStatus.CANCELLED);
task.setErrorMessage("用户取消了任务");
taskRepository.save(task);
logger.info("文生视频任务已取消: {}, 用户: {}", taskId, username);
return true;
}
return false;
}
/**
* 获取待处理任务列表
@@ -416,4 +420,64 @@ public class TextToVideoService {
LocalDateTime expiredDate = LocalDateTime.now().minusDays(30);
return taskRepository.deleteExpiredTasks(expiredDate);
}
/**
* 检查并处理超时的文生视频任务
* 如果任务状态为PROCESSING且创建时间超过10分钟则标记为超时
* 注意如果任务已经有resultUrl视频已生成即使超时也不标记为失败因为视频已经成功生成
*/
@Transactional
public int checkAndHandleTimeoutTasks() {
try {
// 计算超时时间点10分钟前
LocalDateTime timeoutTime = LocalDateTime.now().minusMinutes(10);
// 查找超时的任务状态为PROCESSING创建时间超过10分钟
List<TextToVideoTask> timeoutTasks = taskRepository.findTimeoutTasks(
TextToVideoTask.TaskStatus.PROCESSING,
timeoutTime
);
if (timeoutTasks.isEmpty()) {
return 0;
}
logger.warn("发现 {} 个可能超时的文生视频任务,开始检查", timeoutTasks.size());
int handledCount = 0;
int skippedCount = 0;
for (TextToVideoTask task : timeoutTasks) {
try {
// 检查任务是否已经有resultUrl视频已生成
// 如果有resultUrl说明视频已经成功生成不应该被标记为超时失败
if (task.getResultUrl() != null && !task.getResultUrl().isEmpty()) {
logger.debug("任务 {} 已有resultUrl视频已生成跳过超时标记", task.getTaskId());
skippedCount++;
continue;
}
// 更新任务状态为失败
task.updateStatus(TextToVideoTask.TaskStatus.FAILED);
task.setErrorMessage("文生视频任务超时任务创建后超过10分钟仍未完成");
taskRepository.save(task);
logger.warn("文生视频任务超时,已标记为失败: taskId={}", task.getTaskId());
handledCount++;
} catch (Exception e) {
logger.error("处理超时文生视频任务失败: taskId={}", task.getTaskId(), e);
}
}
if (handledCount > 0 || skippedCount > 0) {
logger.info("处理超时文生视频任务完成,失败: {}/{},跳过(已生成): {}",
handledCount, timeoutTasks.size(), skippedCount);
}
return handledCount;
} catch (Exception e) {
logger.error("检查超时文生视频任务失败", e);
return 0;
}
}
}

View File

@@ -0,0 +1,189 @@
package com.example.demo.service;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.example.demo.model.UserActivityStats;
import com.example.demo.repository.OrderRepository;
import com.example.demo.repository.TaskQueueRepository;
import com.example.demo.repository.UserActivityStatsRepository;
import com.example.demo.repository.UserRepository;
/**
* 用户活跃度统计服务
* 负责计算和保存日活用户数据
*/
@Service
public class UserActivityStatsService {
private static final Logger logger = LoggerFactory.getLogger(UserActivityStatsService.class);
@Autowired
private UserActivityStatsRepository userActivityStatsRepository;
@Autowired
private UserRepository userRepository;
@Autowired
private OrderRepository orderRepository;
@Autowired
private TaskQueueRepository taskQueueRepository;
/**
* 计算并保存指定日期的日活用户数
* 日活用户定义:当天有订单或任务操作的不同用户数
*
* @param date 要统计的日期
*/
@Transactional
public void calculateAndSaveDailyActiveUsers(LocalDate date) {
try {
// 检查是否已经存在该日期的统计数据
List<UserActivityStats> existingList = userActivityStatsRepository.findByActivityDateBetween(date, date);
Optional<UserActivityStats> existing = existingList.stream()
.filter(stats -> stats.getActivityDate().equals(date))
.findFirst();
if (existing.isPresent()) {
logger.info("日期 {} 的统计数据已存在,跳过计算", date);
return;
}
LocalDateTime startOfDay = date.atStartOfDay();
LocalDateTime endOfDay = date.atTime(23, 59, 59);
// 方法1统计当天有订单的不同用户ID
java.util.Set<Long> activeUserIds = orderRepository.findAll().stream()
.filter(order -> order.getCreatedAt() != null &&
order.getCreatedAt().isAfter(startOfDay) &&
order.getCreatedAt().isBefore(endOfDay))
.map(order -> order.getUser() != null ? order.getUser().getId() : null)
.filter(userId -> userId != null)
.collect(java.util.stream.Collectors.toSet());
// 方法2统计当天有任务操作的不同用户通过username查找用户ID
java.util.Set<String> activeUsernames = taskQueueRepository.findAll().stream()
.filter(task -> task.getCreatedAt() != null &&
task.getCreatedAt().isAfter(startOfDay) &&
task.getCreatedAt().isBefore(endOfDay))
.map(task -> task.getUsername())
.filter(username -> username != null && !username.isEmpty())
.collect(java.util.stream.Collectors.toSet());
// 将username转换为userID并合并
activeUsernames.forEach(username -> {
userRepository.findByUsername(username).ifPresent(user -> {
activeUserIds.add(user.getId());
});
});
// 方法3统计当天登录的用户数如果有登录记录表
userRepository.findAll().stream()
.filter(user -> user.getLastLoginAt() != null &&
user.getLastLoginAt().isAfter(startOfDay) &&
user.getLastLoginAt().isBefore(endOfDay))
.forEach(user -> activeUserIds.add(user.getId()));
// 日活用户 = 当天有订单、有任务操作或登录的不同用户数(已去重)
long dailyActiveUsers = activeUserIds.size();
// 统计当天新增用户数
long newUsers = userRepository.findAll().stream()
.filter(user -> user.getCreatedAt() != null &&
user.getCreatedAt().isAfter(startOfDay) &&
user.getCreatedAt().isBefore(endOfDay))
.count();
// 创建或更新统计数据
UserActivityStats stats = existing.orElse(new UserActivityStats());
stats.setActivityDate(date);
stats.setDailyActiveUsers((int) dailyActiveUsers);
stats.setNewUsers((int) newUsers);
stats.setReturningUsers((int) (dailyActiveUsers - newUsers));
// 计算月活用户(当月至少活跃一次的用户数)
LocalDate monthStart = date.withDayOfMonth(1);
long monthlyActiveUsers = userRepository.findAll().stream()
.filter(user -> {
// 检查用户是否在当月有活动(有订单或任务)
boolean hasOrder = orderRepository.findAll().stream()
.anyMatch(order -> order.getUser() != null &&
order.getUser().getId().equals(user.getId()) &&
order.getCreatedAt() != null &&
order.getCreatedAt().isAfter(monthStart.atStartOfDay()) &&
order.getCreatedAt().isBefore(endOfDay));
boolean hasTask = taskQueueRepository.findAll().stream()
.anyMatch(task -> task.getUsername() != null &&
task.getUsername().equals(user.getUsername()) &&
task.getCreatedAt() != null &&
task.getCreatedAt().isAfter(monthStart.atStartOfDay()) &&
task.getCreatedAt().isBefore(endOfDay));
return hasOrder || hasTask;
})
.count();
stats.setMonthlyActiveUsers((int) monthlyActiveUsers);
userActivityStatsRepository.save(stats);
logger.info("日期 {} 的日活用户统计完成: 日活={}, 新增={}, 月活={}",
date, dailyActiveUsers, newUsers, monthlyActiveUsers);
} catch (Exception e) {
logger.error("计算日期 {} 的日活用户统计失败", date, e);
}
}
/**
* 定时任务每天凌晨2点计算前一天的日活用户数据
*/
@Scheduled(cron = "0 0 2 * * ?")
public void calculateYesterdayDailyActiveUsers() {
LocalDate yesterday = LocalDate.now().minusDays(1);
logger.info("开始计算日期 {} 的日活用户统计", yesterday);
calculateAndSaveDailyActiveUsers(yesterday);
}
/**
* 定时任务每天凌晨3点计算当天的日活用户数据实时更新
*/
@Scheduled(cron = "0 0 3 * * ?")
public void calculateTodayDailyActiveUsers() {
LocalDate today = LocalDate.now();
logger.info("开始计算日期 {} 的日活用户统计", today);
calculateAndSaveDailyActiveUsers(today);
}
/**
* 手动触发:计算指定日期的日活用户数据
*/
public void calculateDailyActiveUsersForDate(LocalDate date) {
calculateAndSaveDailyActiveUsers(date);
}
/**
* 批量计算计算最近N天的日活用户数据
*/
@Transactional
public void calculateDailyActiveUsersForRecentDays(int days) {
LocalDate today = LocalDate.now();
for (int i = 0; i < days; i++) {
LocalDate date = today.minusDays(i);
calculateAndSaveDailyActiveUsers(date);
}
logger.info("完成最近 {} 天的日活用户统计", days);
}
}

View File

@@ -19,11 +19,18 @@ public class UserService {
private final UserRepository userRepository;
private final PasswordEncoder passwordEncoder;
private final PointsFreezeRecordRepository pointsFreezeRecordRepository;
private final com.example.demo.repository.OrderRepository orderRepository;
private final com.example.demo.repository.PaymentRepository paymentRepository;
public UserService(UserRepository userRepository, PasswordEncoder passwordEncoder, PointsFreezeRecordRepository pointsFreezeRecordRepository) {
public UserService(UserRepository userRepository, PasswordEncoder passwordEncoder,
PointsFreezeRecordRepository pointsFreezeRecordRepository,
com.example.demo.repository.OrderRepository orderRepository,
com.example.demo.repository.PaymentRepository paymentRepository) {
this.userRepository = userRepository;
this.passwordEncoder = passwordEncoder;
this.pointsFreezeRecordRepository = pointsFreezeRecordRepository;
this.orderRepository = orderRepository;
this.paymentRepository = paymentRepository;
}
@Transactional
@@ -206,8 +213,14 @@ public class UserService {
PointsFreezeRecord record = pointsFreezeRecordRepository.findByTaskId(taskId)
.orElseThrow(() -> new RuntimeException("找不到冻结记录: " + taskId));
// 如果已经扣除过,直接返回,避免重复扣除(防止多线程并发处理)
if (record.getStatus() == PointsFreezeRecord.FreezeStatus.DEDUCTED) {
logger.info("冻结记录 {} 已扣除,跳过重复扣除", taskId);
return;
}
if (record.getStatus() != PointsFreezeRecord.FreezeStatus.FROZEN) {
throw new RuntimeException("冻结记录状态不正确: " + record.getStatus());
throw new RuntimeException("冻结记录状态不正确: " + record.getStatus() + ",期望状态: FROZEN");
}
User user = userRepository.findByUsername(record.getUsername())
@@ -339,4 +352,234 @@ public class UserService {
.orElseThrow(() -> new RuntimeException("用户不存在"));
return user.getPoints();
}
/**
* 获取积分使用历史(充值和使用记录)
* 包括:订单充值记录和积分消耗记录
*/
@Transactional(readOnly = true)
public java.util.List<java.util.Map<String, Object>> getPointsHistory(String username, int page, int size) {
User user = userRepository.findByUsername(username)
.orElseThrow(() -> new RuntimeException("用户不存在"));
java.util.List<java.util.Map<String, Object>> history = new java.util.ArrayList<>();
// 1. 获取成功支付的记录(充值记录)
// 注意:积分是在支付成功时通过 PaymentService.addPointsForPayment 添加的
// 所以应该从支付记录中获取充值记录,而不是从订单中获取
java.util.List<com.example.demo.model.Payment> allPayments = paymentRepository
.findByUserIdOrderByCreatedAtDesc(user.getId());
logger.info("获取用户 {} (ID: {}) 的积分历史记录", username, user.getId());
logger.info("用户 {} 的支付记录总数: {}", username, allPayments.size());
// 打印所有支付记录的状态,用于调试
for (com.example.demo.model.Payment p : allPayments) {
logger.info("支付记录 ID: {}, 状态: {}, 金额: {}, 描述: {}",
p.getId(), p.getStatus(), p.getAmount(), p.getDescription());
}
java.util.List<com.example.demo.model.Payment> successfulPayments = allPayments
.stream()
.filter(payment -> payment.getStatus() == com.example.demo.model.PaymentStatus.SUCCESS)
.collect(java.util.stream.Collectors.toList());
logger.info("用户 {} 的成功支付记录数: {}", username, successfulPayments.size());
for (com.example.demo.model.Payment payment : successfulPayments) {
// 从支付记录中提取积分数量(使用与 PaymentService.addPointsForPayment 相同的逻辑)
Integer points = extractPointsFromPayment(payment);
logger.info("处理支付记录 ID: {}, 金额: {}, 描述: {}, 提取的积分: {}",
payment.getId(), payment.getAmount(), payment.getDescription(), points);
// 即使没有提取到积分,也显示充值记录(可能金额不在套餐范围内,但用户确实支付了)
// 如果提取到积分使用提取的积分否则显示0积分表示支付成功但未获得积分
Integer displayPoints = (points != null && points > 0) ? points : 0;
java.util.Map<String, Object> record = new java.util.HashMap<>();
record.put("type", "充值");
String description = payment.getDescription() != null ? payment.getDescription() : "支付充值";
if (displayPoints == 0 && points == null) {
// 如果未提取到积分,在描述中说明
description = description + "(金额不在套餐范围内,未获得积分)";
}
record.put("description", description);
record.put("points", displayPoints);
record.put("time", payment.getPaidAt() != null ? payment.getPaidAt() : payment.getCreatedAt());
record.put("orderId", payment.getOrderId());
record.put("paymentId", payment.getId());
history.add(record);
logger.info("✓ 添加充值记录: {} 积分, 时间: {}, 描述: {}", displayPoints, record.get("time"), description);
}
logger.info("用户 {} 的充值记录数: {}", username, history.size());
// 2. 也检查已完成订单(作为补充,以防有订单但没有支付记录的情况)
java.util.List<com.example.demo.model.Order> completedOrders = orderRepository.findByUserIdAndStatus(
user.getId(),
com.example.demo.model.OrderStatus.COMPLETED
);
for (com.example.demo.model.Order order : completedOrders) {
// 检查是否已经在支付记录中处理过(避免重复)
boolean alreadyProcessed = successfulPayments.stream()
.anyMatch(p -> p.getOrderId() != null && p.getOrderId().equals(order.getOrderNumber()));
if (!alreadyProcessed) {
// 从订单描述或订单项中提取积分数量
Integer points = extractPointsFromOrder(order);
if (points != null && points > 0) {
java.util.Map<String, Object> record = new java.util.HashMap<>();
record.put("type", "充值");
record.put("description", "订单充值 - " + (order.getDescription() != null ? order.getDescription() : ""));
record.put("points", points);
record.put("time", order.getPaidAt() != null ? order.getPaidAt() : order.getCreatedAt());
record.put("orderNumber", order.getOrderNumber());
record.put("orderType", order.getOrderType() != null ? order.getOrderType().name() : "");
history.add(record);
}
}
}
// 3. 获取积分冻结记录(使用记录)- 只获取已扣除的记录
java.util.List<PointsFreezeRecord> deductedRecords = pointsFreezeRecordRepository
.findByUsernameOrderByCreatedAtDesc(username)
.stream()
.filter(record -> record.getStatus() == PointsFreezeRecord.FreezeStatus.DEDUCTED)
.collect(java.util.stream.Collectors.toList());
for (PointsFreezeRecord record : deductedRecords) {
java.util.Map<String, Object> historyRecord = new java.util.HashMap<>();
historyRecord.put("type", "消耗");
historyRecord.put("description", record.getTaskType().getDescription() + " - " +
(record.getFreezeReason() != null ? record.getFreezeReason() : "任务消耗"));
historyRecord.put("points", -record.getFreezePoints()); // 负数表示消耗
historyRecord.put("time", record.getCompletedAt() != null ? record.getCompletedAt() : record.getCreatedAt());
historyRecord.put("taskId", record.getTaskId());
historyRecord.put("taskType", record.getTaskType().name());
history.add(historyRecord);
}
// 4. 按时间倒序排序
history.sort((a, b) -> {
java.time.LocalDateTime timeA = (java.time.LocalDateTime) a.get("time");
java.time.LocalDateTime timeB = (java.time.LocalDateTime) b.get("time");
return timeB.compareTo(timeA); // 倒序
});
// 5. 分页处理
int start = page * size;
int end = Math.min(start + size, history.size());
if (start >= history.size()) {
return new java.util.ArrayList<>();
}
return history.subList(start, end);
}
/**
* 从订单中提取积分数量
* 这里需要根据实际业务逻辑调整
* 假设订单描述或订单项中包含积分信息
*/
private Integer extractPointsFromOrder(com.example.demo.model.Order order) {
// 方法1从订单描述中提取如果描述包含积分信息
if (order.getDescription() != null) {
java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("(\\d+)积分");
java.util.regex.Matcher matcher = pattern.matcher(order.getDescription());
if (matcher.find()) {
return Integer.valueOf(matcher.group(1));
}
}
// 方法2从订单项中提取如果订单项名称包含积分信息
if (order.getOrderItems() != null && !order.getOrderItems().isEmpty()) {
for (com.example.demo.model.OrderItem item : order.getOrderItems()) {
if (item.getProductName() != null) {
java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("(\\d+)积分");
java.util.regex.Matcher matcher = pattern.matcher(item.getProductName());
if (matcher.find()) {
return Integer.valueOf(matcher.group(1));
}
// 如果是会员订阅,根据订单金额计算积分
if (item.getProductName().contains("标准版") || item.getProductName().contains("专业版")) {
// 标准版200积分/月专业版1000积分/月
if (item.getProductName().contains("标准版")) {
return 200;
} else if (item.getProductName().contains("专业版")) {
return 1000;
}
}
}
}
}
// 方法3根据订单类型和金额估算
if (order.getOrderType() != null) {
if (order.getOrderType() == com.example.demo.model.OrderType.SUBSCRIPTION) {
// 订阅订单:根据金额估算积分
// 标准版:$59 = 200积分专业版$259 = 1000积分
if (order.getTotalAmount() != null) {
double amount = order.getTotalAmount().doubleValue();
if (amount >= 250) {
return 1000; // 专业版
} else if (amount >= 50) {
return 200; // 标准版
}
}
}
}
return null;
}
/**
* 从支付记录中提取积分数量
* 使用与 PaymentService.addPointsForPayment 相同的逻辑
*/
private Integer extractPointsFromPayment(com.example.demo.model.Payment payment) {
if (payment == null) {
return null;
}
java.math.BigDecimal amount = payment.getAmount();
if (amount == null) {
logger.warn("支付记录 ID: {} 的金额为空", payment.getId());
return null;
}
String description = payment.getDescription() != null ? payment.getDescription() : "";
Integer pointsToAdd = 0;
// 优先从描述中识别套餐类型
if (description.contains("标准版") || description.contains("standard") ||
description.contains("Standard") || description.contains("STANDARD")) {
// 标准版订阅 - 200积分
pointsToAdd = 200;
logger.debug("从描述识别为标准版,积分: 200");
} else if (description.contains("专业版") || description.contains("premium") ||
description.contains("Premium") || description.contains("PREMIUM")) {
// 专业版订阅 - 1000积分
pointsToAdd = 1000;
logger.debug("从描述识别为专业版,积分: 1000");
} else {
// 如果描述中没有套餐信息,根据金额判断
// 标准版订阅 (59-258元) - 200积分
if (amount.compareTo(new java.math.BigDecimal("59.00")) >= 0 &&
amount.compareTo(new java.math.BigDecimal("259.00")) < 0) {
pointsToAdd = 200;
logger.debug("根据金额 {} 判断为标准版,积分: 200", amount);
}
// 专业版订阅 (259元以上) - 1000积分
else if (amount.compareTo(new java.math.BigDecimal("259.00")) >= 0) {
pointsToAdd = 1000;
logger.debug("根据金额 {} 判断为专业版,积分: 1000", amount);
} else {
logger.debug("支付金额 {} 不在已知套餐范围内,不增加积分", amount);
}
}
return pointsToAdd > 0 ? pointsToAdd : null;
}
}

View File

@@ -15,10 +15,12 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.example.demo.model.ImageToVideoTask;
import com.example.demo.model.StoryboardVideoTask;
import com.example.demo.model.TextToVideoTask;
import com.example.demo.model.User;
import com.example.demo.model.UserWork;
import com.example.demo.repository.ImageToVideoTaskRepository;
import com.example.demo.repository.StoryboardVideoTaskRepository;
import com.example.demo.repository.TextToVideoTaskRepository;
import com.example.demo.repository.UserWorkRepository;
@@ -43,33 +45,58 @@ public class UserWorkService {
@Autowired
private ImageToVideoTaskRepository imageToVideoTaskRepository;
@Autowired
private StoryboardVideoTaskRepository storyboardVideoTaskRepository;
/**
* 从任务创建作品
*/
@Transactional
public UserWork createWorkFromTask(String taskId, String resultUrl) {
// 检查是否已存在作品
// 检查是否已存在作品(使用同步检查,防止并发创建)
// 注意:这个检查不是原子的,但配合外部的悲观锁应该能防止大部分并发问题
Optional<UserWork> existingWork = userWorkRepository.findByTaskId(taskId);
if (existingWork.isPresent()) {
logger.warn("作品已存在,跳过创建: {}", taskId);
logger.info("作品已存在,跳过创建: taskId={}, workId={}", taskId, existingWork.get().getId());
return existingWork.get();
}
// 尝试从文生视频任务创建作品
Optional<TextToVideoTask> textTaskOpt = textToVideoTaskRepository.findByTaskId(taskId);
if (textTaskOpt.isPresent()) {
TextToVideoTask task = textTaskOpt.get();
return createTextToVideoWork(task, resultUrl);
}
try {
// 尝试从文生视频任务创建作品
Optional<TextToVideoTask> textTaskOpt = textToVideoTaskRepository.findByTaskId(taskId);
if (textTaskOpt.isPresent()) {
TextToVideoTask task = textTaskOpt.get();
return createTextToVideoWork(task, resultUrl);
}
// 尝试从图生视频任务创建作品
Optional<ImageToVideoTask> imageTaskOpt = imageToVideoTaskRepository.findByTaskId(taskId);
if (imageTaskOpt.isPresent()) {
ImageToVideoTask task = imageTaskOpt.get();
return createImageToVideoWork(task, resultUrl);
}
// 尝试从图生视频任务创建作品
Optional<ImageToVideoTask> imageTaskOpt = imageToVideoTaskRepository.findByTaskId(taskId);
if (imageTaskOpt.isPresent()) {
ImageToVideoTask task = imageTaskOpt.get();
return createImageToVideoWork(task, resultUrl);
}
throw new RuntimeException("找不到对应的任务: " + taskId);
// 尝试从分镜视频任务创建作品
Optional<StoryboardVideoTask> storyboardTaskOpt = storyboardVideoTaskRepository.findByTaskId(taskId);
if (storyboardTaskOpt.isPresent()) {
StoryboardVideoTask task = storyboardTaskOpt.get();
return createStoryboardVideoWork(task, resultUrl);
}
throw new RuntimeException("找不到对应的任务: " + taskId);
} catch (org.springframework.dao.DataIntegrityViolationException e) {
// 捕获数据库唯一约束违反异常如果task_id有唯一约束
// 或者捕获其他数据完整性异常
if (e.getMessage() != null && e.getMessage().contains("Duplicate entry")) {
logger.warn("作品可能已存在(数据库约束冲突),重新查询: {}", taskId);
// 重新查询,可能其他线程已经创建了
Optional<UserWork> retryWork = userWorkRepository.findByTaskId(taskId);
if (retryWork.isPresent()) {
return retryWork.get();
}
}
throw e;
}
}
/**
@@ -122,6 +149,31 @@ public class UserWorkService {
return work;
}
/**
* 创建分镜视频作品
*/
private UserWork createStoryboardVideoWork(StoryboardVideoTask task, String resultUrl) {
UserWork work = new UserWork();
work.setUserId(getUserIdByUsername(task.getUsername()));
work.setUsername(task.getUsername());
work.setTaskId(task.getTaskId());
work.setWorkType(UserWork.WorkType.STORYBOARD_VIDEO);
work.setTitle(generateTitle(task.getPrompt()));
work.setDescription("分镜视频作品");
work.setPrompt(task.getPrompt());
work.setResultUrl(resultUrl);
work.setDuration("10s"); // 分镜视频默认10秒
work.setAspectRatio(task.getAspectRatio());
work.setQuality(task.isHdMode() ? "HD" : "SD");
work.setPointsCost(task.getCostPoints());
work.setStatus(UserWork.WorkStatus.COMPLETED);
work.setCompletedAt(LocalDateTime.now());
work = userWorkRepository.save(work);
logger.info("创建分镜视频作品成功: {}, 用户: {}", work.getId(), work.getUsername());
return work;
}
/**
* 根据用户名获取用户ID
*/

View File

@@ -0,0 +1,398 @@
package com.example.demo.service;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
/**
* 视频拼接服务
* 用于将多个视频文件拼接成一个视频
*/
@Service
public class VideoConcatService {
private static final Logger logger = LoggerFactory.getLogger(VideoConcatService.class);
@Value("${app.temp.dir:./temp}")
private String tempDir;
@Value("${app.ffmpeg.path:ffmpeg}")
private String ffmpegPath;
private Boolean ffmpegAvailable = null; // 缓存FFmpeg可用性检测结果
/**
* 规范化FFmpeg路径处理Windows路径中的反斜杠和换行符问题
*/
private String normalizeFFmpegPath(String path) {
if (path == null || path.isEmpty()) {
return "ffmpeg";
}
// 移除所有换行符、回车符和多余空格
String normalized = path.replaceAll("[\\r\\n]", "").trim();
// 将正斜杠转换为反斜杠Windows需要
normalized = normalized.replace("/", "\\");
// 移除多余的反斜杠
normalized = normalized.replaceAll("\\\\+", "\\\\");
return normalized;
}
/**
* 下载视频文件
* @param videoUrl 视频URL
* @param outputPath 输出文件路径
* @return 是否下载成功
*/
public boolean downloadVideo(String videoUrl, String outputPath) {
if (videoUrl == null || videoUrl.isEmpty()) {
logger.error("视频URL为空无法下载");
return false;
}
try {
// 创建临时目录
Path tempPath = Paths.get(tempDir);
if (!Files.exists(tempPath)) {
Files.createDirectories(tempPath);
}
// 下载视频
URI uri = new URI(videoUrl);
URL url = uri.toURL();
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setConnectTimeout(30000); // 30秒连接超时
connection.setReadTimeout(60000); // 60秒读取超时
connection.setRequestMethod("GET");
connection.setRequestProperty("User-Agent", "Mozilla/5.0");
try (InputStream in = connection.getInputStream();
FileOutputStream out = new FileOutputStream(outputPath)) {
byte[] buffer = new byte[8192];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
}
} finally {
connection.disconnect();
}
logger.info("视频下载成功: {} -> {}", videoUrl, outputPath);
return true;
} catch (Exception e) {
logger.error("下载视频失败: {}", videoUrl, e);
return false;
}
}
/**
* 检测FFmpeg是否可用公共方法可用于测试
* @return true如果FFmpeg可用
*/
public boolean isFFmpegAvailable() {
if (ffmpegAvailable != null) {
return ffmpegAvailable;
}
try {
// 规范化路径
String normalizedPath = normalizeFFmpegPath(ffmpegPath);
logger.debug("FFmpeg路径原始: {}", ffmpegPath);
logger.debug("FFmpeg路径规范化后: {}", normalizedPath);
ProcessBuilder processBuilder = new ProcessBuilder(normalizedPath, "-version");
processBuilder.redirectErrorStream(true);
Process process = processBuilder.start();
// 等待最多3秒
boolean finished = process.waitFor(3, java.util.concurrent.TimeUnit.SECONDS);
if (finished && process.exitValue() == 0) {
ffmpegAvailable = true;
logger.info("FFmpeg检测成功路径: {}", normalizedPath);
return true;
} else {
if (!finished) {
process.destroyForcibly();
}
ffmpegAvailable = false;
logger.warn("FFmpeg检测失败退出码: {}", finished ? process.exitValue() : "超时");
return false;
}
} catch (Exception e) {
ffmpegAvailable = false;
logger.warn("FFmpeg不可用: {}", e.getMessage());
return false;
}
}
/**
* 拼接多个视频文件
* 使用FFmpeg进行视频拼接必须安装FFmpeg
* 注意不再使用JCodec作为后备方案因为会导致内存溢出
* @param videoFiles 视频文件路径列表
* @param outputPath 输出文件路径
* @return 是否拼接成功
*/
public boolean concatVideos(List<String> videoFiles, String outputPath) {
if (videoFiles == null || videoFiles.isEmpty()) {
logger.error("视频文件列表为空,无法拼接");
return false;
}
// 首先检测FFmpeg是否可用
if (isFFmpegAvailable()) {
// 尝试使用FFmpeg更快不重新编码
if (tryFFmpegConcat(videoFiles, outputPath)) {
return true;
}
// 如果FFmpeg调用失败重置可用性标志以便下次重新检测
ffmpegAvailable = null;
logger.error("FFmpeg执行失败无法拼接视频");
} else {
logger.error("FFmpeg不可用无法拼接视频。请确保FFmpeg已正确安装并配置");
}
// 不再使用JCodec作为后备方案因为
// 1. JCodec需要将所有视频帧加载到内存会导致内存溢出
// 2. JCodec的拼接功能实现不完整无法正确输出视频
// 3. 对于多个视频文件,内存消耗会非常大(每个视频可能有数百帧)
logger.error("视频拼接失败FFmpeg不可用或执行失败。请检查FFmpeg安装和配置");
return false;
}
/**
* 尝试使用FFmpeg进行视频拼接快速不重新编码
*/
private boolean tryFFmpegConcat(List<String> videoFiles, String outputPath) {
try {
// 创建临时目录
Path tempPath = Paths.get(tempDir);
if (!Files.exists(tempPath)) {
Files.createDirectories(tempPath);
}
// 创建FFmpeg concat文件列表
String concatListFile = tempDir + File.separator + "concat_" + UUID.randomUUID().toString() + ".txt";
try (FileOutputStream fos = new FileOutputStream(concatListFile)) {
for (String videoFile : videoFiles) {
// FFmpeg concat格式file 'path/to/video.mp4'
// 使用相对路径相对于concat文件所在目录
// 由于所有视频文件都在同一个临时目录,相对路径就是文件名
File videoFileObj = new File(videoFile);
String relativePath = videoFileObj.getName(); // 直接使用文件名作为相对路径
// 将Windows路径中的反斜杠转换为正斜杠FFmpeg在Windows上也支持正斜杠
String normalizedFile = relativePath.replace("\\", "/");
String line = "file '" + normalizedFile + "'\n";
fos.write(line.getBytes("UTF-8"));
logger.debug("添加到concat列表: {} (原始文件: {})", normalizedFile, videoFile);
}
}
logger.info("FFmpeg concat文件已创建: {}, 包含 {} 个视频", concatListFile, videoFiles.size());
// 规范化FFmpeg路径
String normalizedPath = normalizeFFmpegPath(ffmpegPath);
// 输出路径保持原样(相对路径或绝对路径都可以)
String normalizedOutputPath = outputPath;
// concat文件路径相对于工作目录
File concatFileObj = new File(concatListFile);
String concatFileName = concatFileObj.getName();
// 执行FFmpeg拼接命令
ProcessBuilder processBuilder = new ProcessBuilder(
normalizedPath,
"-f", "concat",
"-safe", "0",
"-i", concatFileName, // 使用文件名(相对路径)
"-c", "copy",
"-y", // 覆盖输出文件
normalizedOutputPath
);
// 设置工作目录为临时目录这样concat文件和视频文件都在这里
// 使用规范化路径:如果是相对路径,会基于应用运行目录解析;如果是绝对路径,直接使用
File tempDirFile;
if (new File(tempDir).isAbsolute()) {
// 已经是绝对路径,直接使用
tempDirFile = new File(tempDir);
} else {
// 相对路径,转换为绝对路径(基于应用运行目录)
tempDirFile = new File(tempDir).getAbsoluteFile();
}
processBuilder.directory(tempDirFile);
logger.debug("FFmpeg工作目录: {}, concat文件: {}", tempDirFile.getAbsolutePath(), concatFileName);
processBuilder.redirectErrorStream(true);
Process process = processBuilder.start();
// 异步读取FFmpeg输出用于调试
StringBuilder output = new StringBuilder();
Thread outputReader = new Thread(() -> {
try (java.io.BufferedReader reader = new java.io.BufferedReader(
new java.io.InputStreamReader(process.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
synchronized (output) {
output.append(line).append("\n");
// 只记录前1000字符避免日志过长
if (output.length() > 1000) {
break;
}
}
}
} catch (Exception e) {
logger.debug("读取FFmpeg输出时出错: {}", e.getMessage());
}
});
outputReader.setDaemon(true);
outputReader.start();
// 等待进程完成最多等待10分钟因为6个视频拼接可能需要更长时间
boolean finished = process.waitFor(10, java.util.concurrent.TimeUnit.MINUTES);
// 等待输出读取线程完成最多1秒
try {
outputReader.join(1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
// 清理临时文件
try {
Files.deleteIfExists(Paths.get(concatListFile));
} catch (Exception e) {
logger.warn("删除临时文件失败: {}", concatListFile, e);
}
if (finished && process.exitValue() == 0) {
// 验证输出文件是否存在
File outputFile = new File(normalizedOutputPath);
if (outputFile.exists() && outputFile.length() > 0) {
logger.info("使用FFmpeg视频拼接成功: {} 个视频 -> {} (大小: {} bytes)",
videoFiles.size(), normalizedOutputPath, outputFile.length());
return true;
} else {
logger.error("FFmpeg执行成功但输出文件不存在或为空: {}", normalizedOutputPath);
return false;
}
} else {
String outputStr;
synchronized (output) {
outputStr = output.toString();
}
if (!finished) {
process.destroyForcibly();
logger.error("FFmpeg进程超时10分钟输出: {}",
outputStr.length() > 0 ? outputStr.substring(0, Math.min(500, outputStr.length())) : "无输出");
} else {
int exitCode = process.exitValue();
logger.error("FFmpeg退出码: {}, 输出: {}", exitCode,
outputStr.length() > 0 ? outputStr.substring(0, Math.min(500, outputStr.length())) : "无输出");
}
return false;
}
} catch (Exception e) {
logger.debug("FFmpeg不可用或执行失败: {}", e.getMessage());
return false;
}
}
/**
* 使用JCodec进行视频拼接已禁用
*
* 注意:此方法已被禁用,原因:
* 1. JCodec需要将所有视频帧加载到内存会导致OutOfMemoryError
* 2. 对于多个视频文件如6个视频内存消耗会非常大
* 3. JCodec的拼接功能实现不完整无法正确输出视频
*
* 解决方案必须使用FFmpeg进行视频拼接
*
* @param videoFiles 视频文件路径列表
* @param outputPath 输出文件路径
* @return 总是返回false已禁用
*/
@SuppressWarnings("unused")
private boolean tryJCodecConcat(List<String> videoFiles, String outputPath) {
logger.error("JCodec视频拼接已被禁用因为会导致内存溢出。请使用FFmpeg进行视频拼接");
return false;
}
/**
* 下载并拼接多个视频
* @param videoUrls 视频URL列表
* @param outputPath 输出文件路径
* @return 是否成功
*/
public boolean downloadAndConcatVideos(List<String> videoUrls, String outputPath) {
if (videoUrls == null || videoUrls.isEmpty()) {
logger.error("视频URL列表为空无法下载和拼接");
return false;
}
try {
// 创建临时目录
Path tempPath = Paths.get(tempDir);
if (!Files.exists(tempPath)) {
Files.createDirectories(tempPath);
}
// 下载所有视频
List<String> downloadedFiles = new java.util.ArrayList<>();
for (int i = 0; i < videoUrls.size(); i++) {
String videoUrl = videoUrls.get(i);
String tempFile = tempDir + File.separator + "video_" + UUID.randomUUID().toString() + ".mp4";
if (downloadVideo(videoUrl, tempFile)) {
downloadedFiles.add(tempFile);
logger.info("视频 {} 下载成功: {}", i + 1, tempFile);
} else {
logger.error("视频 {} 下载失败: {}", i + 1, videoUrl);
// 清理已下载的文件
for (String file : downloadedFiles) {
try {
Files.deleteIfExists(Paths.get(file));
} catch (Exception e) {
logger.warn("清理临时文件失败: {}", file, e);
}
}
return false;
}
}
// 拼接视频
boolean success = concatVideos(downloadedFiles, outputPath);
// 清理临时文件
for (String file : downloadedFiles) {
try {
Files.deleteIfExists(Paths.get(file));
} catch (Exception e) {
logger.warn("清理临时文件失败: {}", file, e);
}
}
return success;
} catch (Exception e) {
logger.error("下载并拼接视频失败", e);
return false;
}
}
}

View File

@@ -7,10 +7,12 @@ spring.profiles.active=dev
server.address=localhost
server.port=8080
# 文件上传配置
spring.servlet.multipart.max-file-size=10MB
spring.servlet.multipart.max-request-size=20MB
# 文件上传配置扩大请求体大小以支持大图片Base64编码
spring.servlet.multipart.max-file-size=500MB
spring.servlet.multipart.max-request-size=600MB
spring.servlet.multipart.enabled=true
# Tomcat 最大POST大小
server.tomcat.max-http-post-size=600MB
# 应用配置
app.upload.path=uploads

View File

@@ -33,3 +33,6 @@ CREATE TABLE IF NOT EXISTS task_queue (

View File

@@ -32,3 +32,6 @@ CREATE TABLE IF NOT EXISTS points_freeze_records (

View File

@@ -35,3 +35,6 @@ CREATE TABLE task_status (

View File

@@ -577,6 +577,9 @@

View File

@@ -493,6 +493,9 @@

View File

@@ -532,6 +532,9 @@