Skip to content

Commit 323b58f

Browse files
authored
[ITensors] Use task_local_storage for Index ID RNG (#1646)
1 parent 93dc9bd commit 323b58f

File tree

4 files changed

+7
-23
lines changed

4 files changed

+7
-23
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensors"
22
uuid = "9136182c-28ba-11e9-034c-db9fb085ebd5"
33
authors = ["Matthew Fishman <[email protected]>", "Miles Stoudenmire <[email protected]>"]
4-
version = "0.9.3"
4+
version = "0.9.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/ITensors.jl

-4
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,4 @@ include("deprecated.jl")
155155
include("argsdict/argsdict.jl")
156156
include("packagecompile/compile.jl")
157157
include("developer_tools.jl")
158-
159-
function __init__()
160-
return resize!(empty!(INDEX_ID_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
161-
end
162158
end

src/index.jl

+3-15
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,14 @@
11
using NDTensors: NDTensors, sim
22
using .QuantumNumbers: QuantumNumbers, Arrow, In, Neither, Out
3+
using Random: Xoshiro
34
using .TagSets:
45
TagSets, TagSet, @ts_str, addtags, commontags, hastags, removetags, replacetags
56

67
#const IDType = UInt128
78
const IDType = UInt64
89

9-
# Custom RNG for Index id
10-
# Vector of RNGs, one for each thread
11-
const INDEX_ID_RNGs = MersenneTwister[]
12-
@inline index_id_rng() = index_id_rng(Threads.threadid())
13-
@noinline function index_id_rng(tid::Int)
14-
0 < tid <= length(INDEX_ID_RNGs) || _index_id_rng_length_assert()
15-
if @inbounds isassigned(INDEX_ID_RNGs, tid)
16-
@inbounds MT = INDEX_ID_RNGs[tid]
17-
else
18-
MT = MersenneTwister()
19-
@inbounds INDEX_ID_RNGs[tid] = MT
20-
end
21-
return MT
22-
end
23-
@noinline _index_id_rng_length_assert() = @assert false "0 < tid <= length(INDEX_ID_RNGs)"
10+
const _INDEX_ID_RNG_KEY = :ITensors_index_id_rng_bLeTZeEsme4bG3vD
11+
index_id_rng() = get!(task_local_storage(), _INDEX_ID_RNG_KEY, Xoshiro())::Xoshiro
2412

2513
"""
2614
An `Index` represents a single tensor index with fixed dimension `dim`. Copies of an Index compare equal unless their

test/ext/ITensorsTensorOperationsExt/runtests.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ end
150150
tmp * As[n]
151151
allocations_right_associative_pairwise += @allocated tmp = tmp * As[n]
152152
end
153-
@test allocations_right_associative_pairwise allocations_right_associative_1 rtol = 0.1
154-
@test allocations_right_associative_pairwise allocations_right_associative_2 rtol = 0.1
153+
@test allocations_right_associative_pairwise allocations_right_associative_1 rtol = 0.2
154+
@test allocations_right_associative_pairwise allocations_right_associative_2 rtol = 0.2
155155
@test allocations_right_associative_pairwise allocations_right_associative_3 rtol = 0.2
156-
@test allocations_right_associative_pairwise allocations_right_associative_4 rtol = 0.1
156+
@test allocations_right_associative_pairwise allocations_right_associative_4 rtol = 0.2
157157

158158
@test allocations_right_associative_1 < allocations_left_associative
159159
@test allocations_right_associative_2 < allocations_left_associative

0 commit comments

Comments
 (0)