I have not found a built-in way to do this in LangChain, but a small helper function that flattens the output gives me the result I want. It feels a bit clunky, and I suspect there is a better solution.
The key is to pipe the following function onto the end of the chain:
def flatten_dict(*vars) -> dict:
    '''
    Flatten a dictionary by removing unnecessary mid-level keys.
    Returns a plain dict. When the function is piped into a chain,
    LangChain wraps it in a RunnableLambda automatically.
    '''
    flat = {}
    for var in vars:
        for key, value in var.items():
            if isinstance(value, dict):
                # Drop the wrapping key and lift its contents to the top level.
                flat.update(value)
            else:
                flat[key] = value
    return flat
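from operator import itemgetter
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# first_chain and second_chain are assumed to be defined earlier.
chain = (
    {"first_prompt_output": first_chain, "possible_values": RunnablePassthrough(), "first_value": RunnablePassthrough()}
    | RunnableParallel(result={"second_prompt_output": second_chain, "first_value": itemgetter("first_value")})
    | flatten_dict
)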
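
To see what the flattening step does on its own, here is a minimal sketch that runs without LangChain. The `nested` dict and its values are made up for illustration; they only mimic the shape that the `RunnableParallel(result=...)` step should emit, which is a single `result` key wrapping the fields I actually want.

```python
def flatten_dict(*dicts) -> dict:
    """Merge the given dicts, lifting the contents of nested dicts one level up."""
    flat = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict):
                flat.update(value)  # drop the wrapping key, keep its contents
            else:
                flat[key] = value
    return flat


# Hypothetical stand-in for the output of the RunnableParallel(result=...) step.
nested = {"result": {"second_prompt_output": "blue", "first_value": "color"}}

flat = flatten_dict(nested)
print(flat)  # {'second_prompt_output': 'blue', 'first_value': 'color'}

# Top-level values that are not dicts are kept under their own key.
mixed = flatten_dict({"a": 1, "b": {"c": 2}})
print(mixed)  # {'a': 1, 'c': 2}
```

Two things to keep in mind: only one level of nesting is removed, and if two nested dicts share a key, the later one silently overwrites the earlier one.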