-
Notifications
You must be signed in to change notification settings - Fork 332
Expand file tree
/
Copy pathkhQTTools.py
More file actions
4039 lines (3435 loc) · 164 KB
/
khQTTools.py
File metadata and controls
4039 lines (3435 loc) · 164 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# 多进程保护 - 防止在子进程中意外启动Qt应用
import sys
import os
# 检查是否在子进程中,只在子进程中设置环境变量
def is_subprocess():
"""检查是否在子进程中"""
import multiprocessing
try:
current_process = multiprocessing.current_process()
return current_process.name != 'MainProcess'
except:
return False
# 只在子进程中设置环境变量
if is_subprocess():
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
os.environ['QT_LOGGING_RULES'] = 'qt.*=false'
import csv
import time
from datetime import datetime, timedelta
import pandas as pd
from xtquant import xtdata
# from xtquant.xtdata import get_client
import glob
import numpy as np
import logging
import ast
import holidays # 添加这个导入,用于处理holidays.China()
from typing import Dict, List, Union, Optional
import math
from khTrade import KhTradeManager
from types import SimpleNamespace
# 延迟导入Qt相关模块,避免在子进程中意外启动Qt应用
try:
if not is_subprocess():
# 在主进程中正常导入Qt模块
from PyQt5.QtCore import QThread, pyqtSignal
else:
# 在子进程中创建空的占位符类
class QThread:
def __init__(self):
pass
def start(self):
pass
def run(self):
pass
def pyqtSignal(*args, **kwargs):
return lambda: None
except ImportError:
# 如果导入失败,创建空的占位符类
class QThread:
def __init__(self):
pass
def start(self):
pass
def run(self):
pass
def pyqtSignal(*args, **kwargs):
return lambda: None
# ============================================================================
# 独立函数版本 - 可以直接调用,无需实例化类
# ============================================================================
# 初始化全局变量
_trading_periods = [
("093000", "113000"), # 上午
("130000", "150000") # 下午
]
_cn_holidays = holidays.China()
# 默认价格精度(股票为2位,ETF为3位)
_default_price_decimals = 2
def is_etf(stock_code: str) -> bool:
"""判断是否为ETF(不包括LOF)
Args:
stock_code: 股票代码,如 "510300.SH" 或 "159915.SZ"
Returns:
bool: 是否为ETF
说明:
上海ETF: 51(主流)、52(跨境)、53(部分)、55(债券)、56(新规)、58(科创)
深圳ETF: 159开头(深交所ETF统一为159开头)
注意:50/16开头是LOF,不是ETF
"""
# 去除后缀,取前6位数字
code = stock_code.split('.')[0]
# 上海ETF前缀
sh_etf_prefixes = ('51', '52', '53', '55', '56', '58')
# 深圳ETF前缀
sz_etf_prefix = '159'
return code.startswith(sh_etf_prefixes) or code.startswith(sz_etf_prefix)
def determine_pool_type(stock_list: List[str]) -> tuple:
"""判断股票池类型,返回类型和对应的价格精度
Args:
stock_list: 股票代码列表
Returns:
tuple: (pool_type, price_decimals)
pool_type: 'stock_only' | 'etf_only' | 'mixed'
price_decimals: 2(纯股票)或 3(含ETF或混合)
"""
if not stock_list:
return ('stock_only', 2)
has_stock = any(not is_etf(code) for code in stock_list)
has_etf = any(is_etf(code) for code in stock_list)
if has_stock and not has_etf:
# 纯股票池,使用2位小数
return ('stock_only', 2)
elif has_etf and not has_stock:
# 纯ETF池,使用3位小数
return ('etf_only', 3)
else:
# 混合池,使用3位小数
return ('mixed', 3)
# ==================== T+0交易模式相关函数 ====================
# 全局缓存T0 ETF列表,避免重复读取文件
_t0_etf_cache = None
def load_t0_etf_list() -> set:
"""加载T0型ETF列表
Returns:
set: T0型ETF的股票代码集合
"""
global _t0_etf_cache
if _t0_etf_cache is not None:
return _t0_etf_cache
_t0_etf_cache = set()
# 获取T0型ETF.csv文件路径
current_dir = os.path.dirname(os.path.abspath(__file__))
t0_file = os.path.join(current_dir, 'data', 'T0型ETF.csv')
if not os.path.exists(t0_file):
logging.warning(f"T0型ETF列表文件不存在: {t0_file}")
return _t0_etf_cache
try:
with open(t0_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
for row in reader:
if row and len(row) >= 1:
stock_code = row[0].strip()
if stock_code:
_t0_etf_cache.add(stock_code)
logging.info(f"已加载 {len(_t0_etf_cache)} 只T0型ETF")
except Exception as e:
logging.error(f"加载T0型ETF列表失败: {e}")
return _t0_etf_cache
def is_t0_etf(stock_code: str) -> bool:
"""判断单个股票是否支持T+0交易
Args:
stock_code: 股票代码,如 '159001.SZ'
Returns:
bool: 是否支持T+0
"""
t0_list = load_t0_etf_list()
return stock_code in t0_list
def check_t0_support(stock_list: List[str]) -> tuple:
"""检验股票池的T+0支持情况
Args:
stock_list: 股票代码列表
Returns:
tuple: (support_type, is_t0_mode)
support_type: 'all_t0' | 'mixed' | 'no_t0'
is_t0_mode: True(全T+0)/ False(其他情况)
"""
if not stock_list:
return ('no_t0', False)
t0_list = load_t0_etf_list()
t0_count = sum(1 for code in stock_list if code in t0_list)
total_count = len(stock_list)
if t0_count == total_count:
# 全部是T+0 ETF
return ('all_t0', True)
elif t0_count > 0:
# 混合:部分支持T+0,部分不支持
return ('mixed', False)
else:
# 全部不支持T+0
return ('no_t0', False)
def get_t0_details(stock_list: List[str]) -> dict:
"""获取股票池中T+0支持的详细信息
Args:
stock_list: 股票代码列表
Returns:
dict: {
't0_stocks': List[str], # 支持T+0的股票
'non_t0_stocks': List[str], # 不支持T+0的股票
't0_count': int,
'total_count': int
}
"""
t0_list = load_t0_etf_list()
t0_stocks = [code for code in stock_list if code in t0_list]
non_t0_stocks = [code for code in stock_list if code not in t0_list]
return {
't0_stocks': t0_stocks,
'non_t0_stocks': non_t0_stocks,
't0_count': len(t0_stocks),
'total_count': len(stock_list)
}
# ==================== 价格精度相关函数 ====================
def get_price_decimals(data: Dict = None) -> int:
"""从数据字典中获取价格精度设置
Args:
data: 策略接收的数据对象,包含框架信息 __framework__
Returns:
int: 价格精度(小数位数),默认为2
"""
if data is None:
return _default_price_decimals
framework = data.get("__framework__", None)
if framework and hasattr(framework, 'price_decimals'):
return framework.price_decimals
return _default_price_decimals
def round_price(price: float, decimals: int = None, data: Dict = None) -> float:
"""根据精度设置对价格进行四舍五入
Args:
price: 原始价格
decimals: 精度(小数位数),如果为None则从data中获取
data: 策略接收的数据对象
Returns:
float: 四舍五入后的价格
"""
if decimals is None:
decimals = get_price_decimals(data)
return round(price, decimals)
def format_price(price: float, decimals: int = None, data: Dict = None) -> str:
"""根据精度设置格式化价格为字符串
Args:
price: 价格
decimals: 精度(小数位数),如果为None则从data中获取
data: 策略接收的数据对象
Returns:
str: 格式化后的价格字符串
"""
if decimals is None:
decimals = get_price_decimals(data)
return f"{price:.{decimals}f}"
def is_trade_time() -> bool:
"""判断是否为交易时间"""
current = time.strftime("%H%M%S")
for start, end in _trading_periods:
if start <= current <= end:
return True
return False
def is_trade_day(date_str: str = None) -> bool:
"""判断是否为交易日(工作日且非法定节假日)
Args:
date_str: 日期字符串,支持格式:
- "YYYY-MM-DD" (如: "2024-12-25")
- "YYYYMMDD" (如: "20241225")
- None (默认为当天)
Returns:
bool: 是否为交易日
"""
if date_str is None:
date_str = datetime.now().strftime("%Y-%m-%d")
# 标准化日期格式
try:
# 尝试解析不同的日期格式
date_obj = None
# 格式1: YYYY-MM-DD
if '-' in date_str and len(date_str) == 10:
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
# 格式2: YYYYMMDD
elif date_str.isdigit() and len(date_str) == 8:
date_obj = datetime.strptime(date_str, "%Y%m%d")
else:
# 尝试其他可能的格式
for fmt in ["%Y-%m-%d", "%Y%m%d", "%Y/%m/%d"]:
try:
date_obj = datetime.strptime(date_str, fmt)
break
except ValueError:
continue
if date_obj is None:
raise ValueError(f"无法解析日期格式: {date_str}")
# 首先排除周末 (5代表周六, 6代表周日)
if date_obj.weekday() >= 5:
return False
# 使用holidays库判断是否为法定节假日
date_only = date_obj.date()
if date_only in _cn_holidays:
return False
# 非周末且非法定节假日,则视为交易日
return True
except Exception as e:
print(f"判断交易日异常: {str(e)}")
# 如果出现异常,尝试基本的日期解析
try:
# 再次尝试解析常见格式
date_obj = None
for fmt in ["%Y-%m-%d", "%Y%m%d"]:
try:
date_obj = datetime.strptime(date_str, fmt)
break
except ValueError:
continue
if date_obj is None:
print(f"无法解析日期格式: {date_str},默认按交易日处理")
return True
date_only = date_obj.date()
if date_only in _cn_holidays:
print(f"日期 {date_str} 是法定节假日({_cn_holidays.get(date_only)}),非交易日")
return False
# 如果是周末,非交易日
if date_obj.weekday() >= 5:
print(f"日期 {date_str} 是周末,非交易日")
return False
return True
except:
# 实在判断不出来,默认为交易日
print(f"无法确定 {date_str} 是否为交易日,默认按普通工作日处理")
return True
def get_trade_days_count(start_date: str, end_date: str) -> int:
"""计算指定日期范围内的交易日天数
Args:
start_date: 起始日期,格式为"YYYY-MM-DD"
end_date: 结束日期,格式为"YYYY-MM-DD"
Returns:
int: 交易日天数
"""
try:
# 解析起始和结束日期
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
# 确保开始日期不晚于结束日期
if start_dt > end_dt:
logging.error(f"起始日期 {start_date} 晚于结束日期 {end_date}")
return 0
# 初始化计数器
trade_days = 0
# 遍历日期范围内的每一天
current_dt = start_dt
while current_dt <= end_dt:
current_date_str = current_dt.strftime("%Y-%m-%d")
# 使用is_trade_day函数判断是否为交易日
if is_trade_day(current_date_str):
trade_days += 1
# 前进到下一天
current_dt += timedelta(days=1)
logging.info(f"从 {start_date} 到 {end_date} 共有 {trade_days} 个交易日")
return trade_days
except Exception as e:
logging.error(f"计算交易日天数时出错: {str(e)}")
return 0
# ============================================================================
# 兼容性:保留原有的KhQuTools类,但让类方法调用上面的独立函数
# ============================================================================
class KhQuTools:
"""量化工具类(兼容性保留,推荐直接使用模块级函数)"""
def __init__(self):
# 为了兼容性保留这些属性,但实际会使用模块级函数
self.trading_periods = _trading_periods
self.cn_holidays = _cn_holidays
def is_trade_time(self) -> bool:
"""判断是否为交易时间(调用模块级函数)"""
return is_trade_time()
def is_trade_day(self, date_str: str = None) -> bool:
"""判断是否为交易日(调用模块级函数)"""
return is_trade_day(date_str)
def get_trade_days_count(self, start_date: str, end_date: str) -> int:
"""计算指定日期范围内的交易日天数(调用模块级函数)"""
return get_trade_days_count(start_date, end_date)
def calculate_moving_average(self, stock_code: str, period: int, field: str = 'close', fre_step: str = '1d', end_time: Optional[str] = None, fq: str = 'pre') -> float:
"""计算移动平均线
Args:
stock_code: 股票代码
period: 周期长度
field: 计算字段,默认为'close'
fre_step: 时间频率,如'1d', '1m'等
end_time: 结束时间,如果为None使用当前时间
fq: 复权方式,'pre'前复权, 'post'后复权, 'none'不复权
Returns:
float: 移动平均值
Raises:
ValueError: 如果不在交易时间(日内频率)或数据不足
"""
from datetime import datetime
if end_time is None:
now = datetime.now()
if fre_step in ['1m', '5m', 'tick']:
end_time = now.strftime('%Y%m%d %H%M%S')
else:
end_time = now.strftime('%Y%m%d')
# 结合 is_trade_time 判断(仅对日内频率)
if fre_step in ['1m', '5m', 'tick'] and not self.is_trade_time():
raise ValueError("不在交易时间内,无法计算日内移动平均线")
# 获取历史数据(不包含当前时间点)
data = khHistory(
symbol_list=stock_code,
fields=[field],
bar_count=period,
fre_step=fre_step,
current_time=end_time,
fq=fq,
force_download=True # 确保数据最新
)
if stock_code not in data or len(data[stock_code]) < period:
raise ValueError(f"股票 {stock_code} 数据量不足 {period} 条,无法计算 MA{period}")
prices = data[stock_code][field]
# 使用动态精度(根据是否为ETF判断)
decimals = 3 if is_etf(stock_code) else 2
return round(prices.mean(), decimals)
def khMA(stock_code: str, period: int, field: str = 'close', fre_step: str = '1d', end_time: Optional[str] = None, fq: str = 'pre', data: Dict = None) -> float:
"""计算移动平均线(独立函数版本)
Args:
stock_code: 股票代码
period: 周期长度
field: 计算字段,默认为'close'
fre_step: 时间频率,如'1d', '1m'等
end_time: 结束时间,如果为None使用当前时间
fq: 复权方式,'pre'前复权, 'post'后复权, 'none'不复权
data: 策略接收的数据对象,用于获取精度设置(可选)
Returns:
float: 移动平均值
Raises:
ValueError: 如果不在交易时间(日内频率)或数据不足
"""
from datetime import datetime
if end_time is None:
now = datetime.now()
if fre_step in ['1m', '5m', 'tick']:
end_time = now.strftime('%Y%m%d %H%M%S')
else:
end_time = now.strftime('%Y%m%d')
# 结合 is_trade_time 判断(仅对日内频率)
tools = KhQuTools()
if fre_step in ['1m', '5m', 'tick'] and not tools.is_trade_time():
raise ValueError("不在交易时间内,无法计算日内移动平均线")
# 获取历史数据(不包含当前时间点)
history_data = khHistory(
symbol_list=stock_code,
fields=[field],
bar_count=period,
fre_step=fre_step,
current_time=end_time,
fq=fq,
force_download=False # 不强制下载数据,提高回测速度
)
if stock_code not in history_data or len(history_data[stock_code]) < period:
raise ValueError(f"股票 {stock_code} 数据量不足 {period} 条,无法计算均线{period}")
prices = history_data[stock_code][field]
# 优先从传入的data中获取精度设置,否则根据股票代码判断
decimals = get_price_decimals(data) if data else (3 if is_etf(stock_code) else 2)
return round(prices.mean(), decimals)
def calculate_max_buy_volume(data: Dict, stock_code: str, price: float, cash_ratio: float = 1.0) -> int:
"""
计算最大可买入数量,考虑交易成本(包括滑点)
Args:
data: 策略接收的数据对象,包含账户信息 __account__ 和框架信息 __framework__
stock_code: 股票代码
price: 当前价格
cash_ratio: 使用可用资金的比例,默认为1.0表示使用全部可用资金
Returns:
int: 最大可买入股数(按手取整)
"""
try:
# 导入交易管理类
from khTrade import KhTradeManager
# 获取账户信息
account_info = data.get("__account__", {})
if not account_info:
logging.warning("无法获取账户信息,无法计算最大买入量")
return 0
# 获取资金信息
available_cash = account_info.get("cash", 0.0)
# 计算可用的资金
usable_cash = available_cash * cash_ratio
# 防止价格为0导致除零错误
if price <= 0:
logging.warning(f"股票 {stock_code} 价格异常: {price},无法计算买入量")
return 0
# 获取价格精度设置
decimals = get_price_decimals(data)
# 对价格进行四舍五入处理
price = round(price, decimals)
# 获取框架对象
framework = data.get("__framework__", None)
# 获取配置对象
if framework and hasattr(framework, 'config'):
config = framework.config
else:
logging.warning("未从数据字典中获取到框架对象或框架配置不可用,将使用默认交易成本设置")
config = SimpleNamespace(config_dict={"backtest": {"trade_cost": {}}})
# 创建交易管理器实例(使用实际配置)
trade_manager = KhTradeManager(config)
# 获取交易成本参数
commission_rate = trade_manager.commission_rate
transfer_fee_rate = 0.00001 if stock_code.startswith("sh.") else 0.0 # 沪市股票才有过户费
# 估算最大股数 (向下取整到100的倍数)
# 使用更精确的初始估算方式
estimated_shares = math.floor(usable_cash / price / (1 + commission_rate + transfer_fee_rate))
shares = math.floor(estimated_shares / 100) * 100
# 如果估算股数小于100,则无法买入
if shares < 100:
return 0
# 逐步减少股数,使用calculate_trade_cost精确计算成本
while shares >= 100:
# 使用calculate_trade_cost计算实际交易成本(包括滑点)
actual_price, trade_cost = trade_manager.calculate_trade_cost(
price=price,
volume=shares,
direction="buy",
stock_code=stock_code
)
# 计算总花费(实际价格 * 数量 + 交易成本)
total_cost = actual_price * shares + trade_cost
if total_cost <= usable_cash:
logging.info(f"计算买入量: 股票={stock_code}, 原始价格={price:.{decimals}f}, 考虑滑点后价格={actual_price:.{decimals}f}, "
f"可用现金={available_cash:.{decimals}f}, 使用比例={cash_ratio:.2f}, "
f"计划买入={shares}, 成本={trade_cost:.2f}, 总花费={total_cost:.{decimals}f}")
return int(shares) # 确保返回整数
shares -= 100 # 减少一手
return 0 # 循环结束仍未找到合适的买入量
except Exception as e:
logging.error(f"计算最大可买入数量时出错: {str(e)}", exc_info=True)
return 0
def generate_signal(data: Dict, stock_code: str, price: float, ratio: float, action: str, reason: str = "") -> List[Dict]:
"""
生成标准交易信号
Args:
data: 包含时间、账户、持仓信息的字典,以及框架信息 __framework__
stock_code: 股票代码
price: 交易价格
ratio: 当ratio≤1时表示交易比例(买入时指占剩余现金比例,卖出时指占可卖持仓比例)
当ratio>1时表示买入的股数(必须是100的整数倍)
action: 'buy' 或 'sell'
reason: 交易原因
Returns:
List[Dict]: 包含单个信号的列表,或空列表
"""
signals = []
current_time = data.get("__current_time__", {})
timestamp = current_time.get("timestamp")
# 获取价格精度设置
decimals = get_price_decimals(data)
# 对价格进行四舍五入处理
price = round(price, decimals)
if action == "buy":
# 判断ratio是否大于1,若大于1则表示买入股数
if ratio > 1:
# 检查股数是否为整百
target_volume = int(ratio)
if target_volume % 100 != 0:
error_msg = f"买入股数必须是100的整数倍: 股票={stock_code}, 输入股数={target_volume}"
logging.error(error_msg)
return []
# 计算最大可买入量进行验证
max_volume = calculate_max_buy_volume(data, stock_code, price, cash_ratio=1.0)
if max_volume == 0:
logging.warning(f"无法生成买入信号: 股票={stock_code}, 价格={price:.{decimals}f}, 目标股数={target_volume}, 但资金不足无法买入")
return []
elif target_volume > max_volume:
logging.warning(f"目标买入量超过最大可买入量: 股票={stock_code}, 目标={target_volume}, 最大可买={max_volume}, 将调整为最大可买入量")
actual_volume = max_volume
else:
actual_volume = target_volume
signal = {
"code": stock_code,
"action": "buy",
"price": price, # 价格已在函数开始时四舍五入
"volume": actual_volume,
"reason": reason or f"按价格 {price:.{decimals}f} 买入 {actual_volume}股({actual_volume//100}手)"
}
if timestamp:
signal["timestamp"] = timestamp
signals.append(signal)
logging.info(f"生成买入信号: {signal}")
else:
# ratio <= 1时按照资金比例计算可买入股数
max_volume = calculate_max_buy_volume(data, stock_code, price, cash_ratio=ratio)
if max_volume > 0:
signal = {
"code": stock_code,
"action": "buy",
"price": price, # 价格已在函数开始时四舍五入
"volume": max_volume,
"reason": reason or f"按价格 {price:.{decimals}f} 以 {ratio*100:.0f}% 资金比例买入"
}
if timestamp:
signal["timestamp"] = timestamp
signals.append(signal)
logging.info(f"生成买入信号: {signal}")
else:
logging.warning(f"无法生成买入信号: 股票={stock_code}, 价格={price:.{decimals}f}, 资金比例={ratio:.2f}, 计算可买量为0")
elif action == "sell":
positions_info = data.get("__positions__", {})
if stock_code in positions_info:
# 获取可卖数量,优先使用 'can_use_volume',否则用 'volume'
position_data = positions_info[stock_code]
available_volume = position_data.get("can_use_volume", position_data.get("volume", 0))
if available_volume > 0:
# 计算要卖出的股数 (向下取整到100的倍数)
sell_volume = math.floor((available_volume * ratio) / 100) * 100
if sell_volume > 0:
signal = {
"code": stock_code,
"action": "sell",
"price": price, # 价格已在函数开始时四舍五入
"volume": int(sell_volume), # 确保是整数
"reason": reason or f"按价格 {price:.{decimals}f} 卖出 {ratio*100:.0f}% 可用持仓"
}
if timestamp:
signal["timestamp"] = timestamp
signals.append(signal)
logging.info(f"生成卖出信号: {signal}")
else:
logging.warning(f"无法生成卖出信号: 股票={stock_code}, 价格={price:.{decimals}f}, 持仓比例={ratio:.2f}, 计算可卖量为0 (可用持仓={available_volume})")
else:
logging.warning(f"无法生成卖出信号: 股票={stock_code} 无可用持仓")
else:
logging.warning(f"无法生成卖出信号: 股票={stock_code} 不在持仓中")
return signals
def read_stock_csv(file_path):
"""
读取股票CSV文件,支持多种编码格式,并进行错误处理。
参数:
- file_path: CSV文件路径
返回:
- tuple: (股票代码列表, 股票名称列表)
"""
if not os.path.isfile(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
# 尝试的编码列表
encodings = ['utf-8', 'gb18030', 'gbk', 'gb2312', 'utf-16', 'ascii']
# 存储结果
stock_codes = []
stock_names = []
# 尝试不同的编码
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as file:
# 先读取少量内容来验证编码是否正确
file.read(1024)
file.seek(0) # 重置文件指针到开始
csv_reader = csv.reader(file)
# 检查是否有BOM
first_row = next(csv_reader)
if first_row and first_row[0].startswith('\ufeff'):
first_row[0] = first_row[0][1:] # 删除BOM
# 处理第一行
process_row(first_row, stock_codes, stock_names)
# 处理剩余行
for row in csv_reader:
process_row(row, stock_codes, stock_names)
# 如果成功读取,跳出循环
break
except UnicodeDecodeError:
# 如果是最后一个编码仍然失败,则抛出异常
if encoding == encodings[-1]:
raise Exception(f"无法读取文件 {file_path},已尝试以下编码:{', '.join(encodings)}")
continue
except Exception as e:
# 处理其他可能的异常
raise Exception(f"读取文件 {file_path} 时发生错误: {str(e)}")
return stock_codes, stock_names
def process_row(row, stock_codes, stock_names):
"""
处理CSV的单行数据,处理带有交易所后缀的股票代码
参数:
- row: CSV行数据
- stock_codes: 股票代码列表(会被修改)
- stock_names: 股票名称列表(会被修改)
"""
if len(row) >= 2:
stock_code = row[0].strip()
stock_name = row[1].strip()
logging.info(f"处理股票: {stock_code} - {stock_name}")
# 检查股票代码格式 - 简化筛选,只要有交易所后缀就接受
if '.' in stock_code: # 已经包含后缀
# 接受所有标准格式的证券代码(包括股票、ETF、指数、可转债等)
if stock_code.endswith(('.SH', '.SZ', '.BJ')): # 支持上海、深圳、北交所
stock_codes.append(stock_code)
stock_names.append(stock_name)
logging.info(f"添加证券: {stock_code} - {stock_name}")
else:
logging.info(f"跳过证券(交易所代码不支持): {stock_code}")
else:
logging.info(f"跳过证券(无交易所后缀): {stock_code}")
def download_and_store_data(local_data_path, stock_files, field_list, period_type, start_date, end_date, dividend_type='none', time_range='all', progress_callback=None, log_callback=None, check_interrupt=None):
"""
下载并存储指定股票、字段、周期类型和时间段的数据到文件。
函数功能:
1. 从指定的股票代码列表文件中读取股票代码。
2. 创建本地数据存储目录(如果不存在)。
3. 对于每只股票:
- 下载指定周期类型的数据到本地。
- 从本地读取指定字段的数据。
- 将 "time" 列转换为日期时间格式。
- 如果指定了时间段,则筛选出指定时间段内的数据。
- 将筛选出的数据添加到结果 DataFrame 中。
- 将结果 DataFrame 存储到本地文件。
4. 输出数据读取和存储完成的提示信息。
文件命名规则:
- 存储的文件名格式: "{股票代码}_{周期类型}_{起始日期}_{结束日期}_{时间段}_{复权方式}.csv"
- 示例1: "000001.SZ_tick_20240101_20240430_all_none.csv"
- 股票代码: 000001.SZ
- 周期类型: tick
- 起始日期: 20240101
- 结束日期: 20240430
- 时间段: all (表示全部时间段)
- 复权方式: none (表示不复权)
- 示例2: "000001.SZ_1d_20240101_20240430_all_front.csv"
- 复权方式: front (表示前复权)
- 如果指定了具体的时间段,时间段部分将替换为 "HH_MM-HH_MM" 的格式
- 示例: "000001.SZ_1m_20240101_20240430_09_30-11_30_none.csv"
- 时间段: 09_30-11_30 (表示 09:30 到 11:30 的时间段)
参数:
- local_data_path (str): 本地数据存储路径。
- 该参数指定存储数据的本地目录路径。
- 如果目录不存在,会自动创建。
- 示例: "I:/stock_data_all_2"
- stock_files (list): 股票代码列表文件路径列表。
- 该参数指定包含股票代码的文件路径列表。
- 每个文件应包含股票代码和名称两列。
- 支持的股票类型:
- A股:上海(600/601/603/605/688)、深圳(000/002/300/301)
- 指数:上证(000)、深证(399)
- 示例: ["HS300idx.csv", "otheridx.csv"]
- field_list (list): 要存储的字段列表。
- 该参数指定要下载和存储的股票数据字段列表。
- 常用字段包括:open(开盘价)、high(最高价)、low(最低价)、close(收盘价)、
volume(成交量)、amount(成交额)等。
- 示例: ["open", "high", "low", "close", "volume"]
- period_type (str): 要读取的周期类型。
- 该参数指定要下载和存储的数据周期类型。
- 可选值:
- 'tick': 逐笔数据
- '1m': 1分钟线
- '5m': 5分钟线
- '1d': 日线数据
- 示例: "1d"
- start_date (str): 起始日期。
- 该参数指定数据的起始日期。
- 格式为 "YYYYMMDD"。
- 示例: "20240101"
- end_date (str): 结束日期。
- 该参数指定数据的结束日期。
- 格式为 "YYYYMMDD"。
- 示例: "20240430"
- dividend_type (str, optional): 复权方式。
- 该参数指定数据的复权方式,默认为'none'。
- 可选值:
- 'none': 不复权,使用原始价格
- 'front': 前复权,基于最新价格进行前复权计算
- 'back': 后复权,基于首日价格进行后复权计算
- 'front_ratio': 等比前复权,基于最新价格进行等比前复权计算
- 'back_ratio': 等比后复权,基于首日价格进行等比后复权计算
- 注意:复权设置仅对股票价格数据有效,对指数和成交量等数据无影响
- 示例: "front"
- time_range (str, optional): 要读取的时间段。
- 该参数指定要筛选的数据时间段。
- 格式为 "HH:MM-HH:MM"。
- 如果指定为 "all",则不进行时间段筛选,保留全部时间段的数据。
- 仅对分钟和tick级别数据有效。
- 示例: "09:30-11:30" 或 "all"
- progress_callback (function, optional): 进度回调函数。
- 该函数用于更新下载进度。
- 接受一个整数参数,表示完成百分比(0-100)。
- 可用于更新GUI进度条等。
- log_callback (function, optional): 日志回调函数。
- 该函数用于记录处理过程中的日志信息。
- 接受一个字符串参数,表示日志消息。
- 可用于在GUI中显示处理状态等。
- check_interrupt (function, optional): 中断检查函数。
- 该函数用于检查是否需要中断下载过程。
- 返回True表示需要中断,返回False表示继续执行。
返回值:
- 无返回值,数据直接保存到指定目录。
异常:
- 如果股票代码文件不存在或格式错误,会记录警告并跳过。
- 如果数据下载失败,会记录错误并继续处理下一只股票。
- 如果保存文件失败,会记录错误信息。
- 如果中断检查函数返回True,会抛出InterruptedError异常。
"""
try:
# 获取所有股票代码
stocks = []
for stock_file in stock_files:
# 检查是否需要中断
if check_interrupt and check_interrupt():
logging.info("下载过程被中断")
raise InterruptedError("下载过程被用户中断")
if os.path.exists(stock_file):
logging.info(f"读取股票文件: {stock_file}")
codes, names = read_stock_csv(stock_file)
stocks.extend(codes)
logging.info(f"股票列表: {stocks}")
if not os.path.exists(local_data_path):
os.makedirs(local_data_path)
total_stocks = len(stocks)
for index, stock in enumerate(stocks, 1):
try:
# 检查是否需要中断
if check_interrupt and check_interrupt():
logging.info("下载过程被中断")
raise InterruptedError("下载过程被用户中断")
if log_callback:
log_callback(f"正在处理 {stock} ({index}/{total_stocks})")
# 判断是否为指数
is_index = stock in ["000001.SH", "399001.SZ", "399006.SZ", "000688.SH",
"000300.SH", "000905.SH", "000852.SH"]
try:
# 每次主要操作前检查中断
if check_interrupt and check_interrupt():
logging.info("下载过程被中断")
raise InterruptedError("下载过程被用户中断")
if is_index:
# 指数数据处理
logging.info(f"获取指数数据: {stock}")
xtdata.download_history_data(stock, period=period_type,
start_time=start_date, end_time=end_date)
# 再次检查中断
if check_interrupt and check_interrupt():
logging.info("下载过程被中断")
raise InterruptedError("下载过程被用户中断")
data = xtdata.get_market_data_ex(
field_list=['time'] + field_list,
stock_list=[stock],
period=period_type,
start_time=start_date,
end_time=end_date,
count=-1,
dividend_type=dividend_type, # 添加复权参数
fill_data=True
)
if data and stock in data:
df = data[stock]
logging.info(f"成功获取指数数据: {stock}")
else: