Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
m => 2.0

julia> values_as(vi, Vector)
2-element Vector{Real}:
2-element Vector{Float64}:
1.0
2.0
```
Expand Down
81 changes: 57 additions & 24 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both
have the same symbol `x`.

Several type aliases are provided for these forms of VarInfos:
- `VarInfo{<:Metadata}` is `UntypedVarInfo`
- `VarInfo{<:Metadata}` is `UntypedLegacyVarInfo`
- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo`
- `VarInfo{<:NamedTuple}` is `NTVarInfo`

Expand All @@ -107,7 +107,7 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
metadata::Tmeta
accs::Accs
end
function VarInfo(meta=Metadata())
function VarInfo(meta=VarNamedVector())
return VarInfo(meta, default_accumulators())
end

Expand Down Expand Up @@ -143,7 +143,7 @@ function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior
end

const UntypedVectorVarInfo = VarInfo{<:VarNamedVector}
const UntypedVarInfo = VarInfo{<:Metadata}
const UntypedLegacyVarInfo = VarInfo{<:Metadata}
# TODO: NTVarInfo carries no information about the type of the actual metadata
# i.e. the elements of the NamedTuple. It could be Metadata or it could be
# VarNamedVector.
Expand All @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
}
const UntypedVarInfo = UntypedVectorVarInfo

function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
Expand Down Expand Up @@ -194,8 +195,20 @@ end
# VarInfo constructors #
########################

function untyped_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return untyped_vector_varinfo(rng, model, init_strategy)
end

function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return untyped_vector_varinfo(Random.default_rng(), model, init_strategy)
end

"""
untyped_varinfo([rng, ]model[, init_strategy])
untyped_legacy_varinfo([rng, ]model[, init_strategy])

Construct a VarInfo object for the given `model`, which has just a single
`Metadata` as its metadata field.
Expand All @@ -205,27 +218,29 @@ Construct a VarInfo object for the given `model`, which has just a single
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_varinfo(
function untyped_legacy_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return last(init!!(rng, model, VarInfo(Metadata()), init_strategy))
end
function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return untyped_varinfo(Random.default_rng(), model, init_strategy)
function untyped_legacy_varinfo(
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
)
return untyped_legacy_varinfo(Random.default_rng(), model, init_strategy)
end

"""
typed_varinfo(vi::UntypedVarInfo)
typed_legacy_varinfo(vi::UntypedLegacyVarInfo)

This function finds all the unique `sym`s from the instances of `VarName{sym}` found in
`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the
global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as
a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each
symbol.
"""
function typed_varinfo(vi::UntypedVarInfo)
function typed_legacy_varinfo(vi::UntypedLegacyVarInfo)
meta = vi.metadata
new_metas = Metadata[]
# Symbols of all instances of `VarName{sym}` in `vi.vns`
Expand Down Expand Up @@ -289,12 +304,16 @@ function typed_varinfo(
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return typed_varinfo(untyped_varinfo(rng, model, init_strategy))
return typed_vector_varinfo(rng, model, init_strategy)
end
function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return typed_varinfo(Random.default_rng(), model, init_strategy)
end

function typed_varinfo(vi::UntypedVectorVarInfo)
return typed_vector_varinfo(vi)
end

"""
untyped_vector_varinfo([rng, ]model[, init_strategy])

Expand All @@ -306,7 +325,7 @@ Return a VarInfo object for the given `model`, which has just a single
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_vector_varinfo(vi::UntypedVarInfo)
function untyped_vector_varinfo(vi::UntypedLegacyVarInfo)
md = metadata_to_varnamedvector(vi.metadata)
return VarInfo(md, copy(vi.accs))
end
Expand Down Expand Up @@ -626,11 +645,11 @@ end
const VarView = Union{Int,UnitRange,Vector{Int}}

"""
setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})
setval!(vi::UntypedLegacyVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})

Set the value of `vi.vals[vview]` to `val`.
"""
setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val
setval!(vi::UntypedLegacyVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val

"""
getmetadata(vi::VarInfo, vn::VarName)
Expand Down Expand Up @@ -825,10 +844,10 @@ set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi,

Returns a tuple of the unique symbols of random variables in `vi`.
"""
syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::UntypedLegacyVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::NTVarInfo) = keys(vi.metadata)

_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::UntypedLegacyVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata)

@generated function _getidcs(metadata::NamedTuple{names}) where {names}
Expand Down Expand Up @@ -949,7 +968,7 @@ function link!!(
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
end

function _link!!(vi::UntypedVarInfo, vns)
function _link!!(vi::UntypedLegacyVarInfo, vns)
# TODO: Change to a lazy iterator over `vns`
if ~is_transformed(vi, vns[1])
for vn in vns
Expand Down Expand Up @@ -1063,7 +1082,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
return maybe_invlink_before_eval!!(t, vi, model)
end

function _invlink!!(vi::UntypedVarInfo, vns)
function _invlink!!(vi::UntypedLegacyVarInfo, vns)
if is_transformed(vi, vns[1])
for vn in vns
f = linked_internal_to_internal_transform(vi, vn)
Expand Down Expand Up @@ -1297,6 +1316,10 @@ function _link_metadata!!(
metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
set_transformed!(metadata, true, vn)
end
# Linking can often change the sizes of variables, causing inactive elements. We don't
# want to keep them around, since typically linking is done once and then the VarInfo
# is evaluated multiple times. Hence we contiguify here.
metadata = contiguify!(metadata)
return metadata, cumulative_logjac
end

Expand Down Expand Up @@ -1465,11 +1488,15 @@ function _invlink_metadata!!(
metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform)
set_transformed!(metadata, false, vn)
end
# Linking can often change the sizes of variables, causing inactive elements. We don't
# want to keep them around, since typically linking is done once and then the VarInfo
# is evaluated multiple times. Hence we contiguify here.
metadata = contiguify!(metadata)
return metadata, cumulative_inv_logjac
end

# TODO(mhauru) The treatment of the case when some variables are transformed and others are
# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed`
# not should be revised. It used to be the case that for UntypedLegacyVarInfo `is_transformed`
# returned whether the first variable was linked. For NTVarInfo we did an OR over the first
# variables under each symbol. We now more consistently use OR, but I'm not convinced this
# is really the right thing to do.
Expand Down Expand Up @@ -1559,9 +1586,15 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`.
The value(s) may or may not be transformed to Euclidean space.
"""
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)

function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
setindex!(vi, val, vn)
return vi
md = setindex!!(getmetadata(vi, vn), val, vn)
return VarInfo(md, vi.accs)
end

function BangBang.setindex!!(vi::NTVarInfo, val, vn::VarName)
submd = setindex!!(getmetadata(vi, vn), val, vn)
return Accessors.@set vi.metadata[getsym(vn)] = submd
end

@inline function findvns(vi, f_vns)
Expand All @@ -1586,7 +1619,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
return any(md_haskey)
end

function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
function Base.show(io::IO, ::MIME"text/plain", vi::UntypedLegacyVarInfo)
lines = Tuple{String,Any}[
("VarNames", vi.metadata.vns),
("Range", vi.metadata.ranges),
Expand Down Expand Up @@ -1641,7 +1674,7 @@ function _show_varnames(io::IO, vi)
end
end

function Base.show(io::IO, vi::UntypedVarInfo)
function Base.show(io::IO, vi::UntypedLegacyVarInfo)
print(io, "VarInfo (")
_show_varnames(io, vi)
print(io, "; accumulators: ")
Expand Down Expand Up @@ -1813,11 +1846,11 @@ end

values_as(vi::VarInfo) = vi.metadata
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
function values_as(vi::UntypedLegacyVarInfo, ::Type{NamedTuple})
iter = values_from_metadata(vi.metadata)
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
end
function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict}
function values_as(vi::UntypedLegacyVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata))
end

Expand Down
43 changes: 34 additions & 9 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,13 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
vnv_left.num_inactive == vnv_right.num_inactive
end

function is_concretely_typed(vnv::VarNamedVector)
return isconcretetype(eltype(vnv.varnames)) &&
isconcretetype(eltype(vnv.vals)) &&
isconcretetype(eltype(vnv.transforms))
function is_tightly_typed(vnv::VarNamedVector)
k = eltype(vnv.varnames)
v = eltype(vnv.vals)
t = eltype(vnv.transforms)
return (isconcretetype(k) || k === Union{}) &&
(isconcretetype(v) || v === Union{}) &&
(isconcretetype(t) || t === Union{})
end

getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn]
Expand Down Expand Up @@ -880,7 +883,16 @@ function loosen_types!!(
return if vn_type == K && val_type == V && transform_type == T
vnv
elseif isempty(vnv)
VarNamedVector(vn_type[], val_type[], transform_type[])
VarNamedVector(
Dict{vn_type,Int}(),
Vector{vn_type}(),
UnitRange{Int}[],
Vector{val_type}(),
Vector{transform_type}(),
BitVector(),
Dict{Int,Int}();
check_consistency=false,
)
else
# TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but
# then here always revert to Vector.
Expand Down Expand Up @@ -944,7 +956,7 @@ julia> vnv_tight.transforms
```
"""
function tighten_types!!(vnv::VarNamedVector)
return if is_concretely_typed(vnv)
return if is_tightly_typed(vnv)
# There can not be anything to tighten, so short-circuit.
vnv
elseif isempty(vnv)
Expand Down Expand Up @@ -1020,6 +1032,7 @@ function insert_internal!!(
end
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform))
insert_internal!(vnv, val, vn, transform)
vnv = tighten_types!!(vnv)
return vnv
end

Expand All @@ -1029,6 +1042,7 @@ function update_internal!!(
transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved))
update_internal!(vnv, val, vn, transform)
vnv = tighten_types!!(vnv)
return vnv
end

Expand Down Expand Up @@ -1104,6 +1118,9 @@ care about them.

This is in a sense the reverse operation of `vnv[:]`.

The return value may share memory with the input `vnv`, and thus one can not be mutated
safely without affecting the other.

Unflatten recontiguifies the internal storage, getting rid of any inactive entries.

# Examples
Expand All @@ -1125,15 +1142,20 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector)
),
)
end
new_ranges = deepcopy(vnv.ranges)
recontiguify_ranges!(new_ranges)
new_ranges = vnv.ranges
num_inactive = vnv.num_inactive
if has_inactive(vnv)
new_ranges = recontiguify_ranges!(new_ranges)
num_inactive = Dict{Int,Int}()
end
return VarNamedVector(
vnv.varname_to_index,
vnv.varnames,
new_ranges,
vals,
vnv.transforms,
vnv.is_unconstrained;
vnv.is_unconstrained,
num_inactive;
check_consistency=false,
)
end
Expand Down Expand Up @@ -1428,6 +1450,9 @@ julia> vnv[@varname(x)] # All the values are still there.
```
"""
function contiguify!(vnv::VarNamedVector)
if !has_inactive(vnv)
return vnv
end
# Extract the re-contiguified values.
# NOTE: We need to do this before we update the ranges.
old_vals = copy(vnv.vals)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Expand Down Expand Up @@ -34,6 +35,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.13"
Accessors = "0.1"
Aqua = "0.8"
BangBang = "0.4"
Bijectors = "0.15.1"
Combinatorics = "1"
DifferentiationInterface = "0.6.41, 0.7"
Expand Down
9 changes: 6 additions & 3 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()

@testset "InitContext" begin
empty_varinfos = [
("untyped+metadata", VarInfo()),
("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())),
("untyped+metadata", VarInfo(DynamicPPL.Metadata())),
(
"typed+metadata",
DynamicPPL.typed_legacy_varinfo(VarInfo(DynamicPPL.Metadata())),
),
("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())),
(
"typed+VNV",
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
DynamicPPL.typed_vector_varinfo(VarInfo(DynamicPPL.VarNamedVector())),
),
("SVI+NamedTuple", SimpleVarInfo()),
("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())),
Expand Down
Loading
Loading